mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
AMD/ROCm OCP Micro-scaling Format (mx-fp8/mx-fp4) Support (#151360)
- This pull request introduces support for the [OCP Micro-scaling (MX) format](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf), with a focus on compatibility with AMD **ROCm 7.0** and the **gfx950** architecture. This PR also establishes the foundation for enabling MX-FPX features in [TorchAO](https://github.com/pytorch/ao/issues/2229) on the AMD platform. - Validation (**ROCm 7.0** + **gfx950** required): `111 relevant tests passing.` > PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v Co-author: @jagadish-amd — Thank you for the efforts leading validation on gfx950 with ROCm 7.0. ----------------------------------- This pull request introduces support for new scalar types and scaling methods, particularly for ROCm 7.0 and gfx950, and refines testing for these features. Key changes include adding constraints for matrix dimensions, enabling block-wise scaling, and updating tests to accommodate new data types. ### Support for new scalar types and scaling methods: * [`aten/src/ATen/cuda/CUDABlas.cpp`](diffhunk://#diff-74fcb26047c1df4024105d36ce22a36b77cf8cc93c28631d743e639b3d6066aeR1876-R1885): Added constraints for matrix dimensions when using `Float8_e8m0fnu` with block-wise scaling, ensuring dimensions are multiples of 32. Updated compatibility checks to support ROCm 7.0 for `Float8_e8m0fnu` and `Float8_e4m3fn`. [[1]](diffhunk://#diff-74fcb26047c1df4024105d36ce22a36b77cf8cc93c28631d743e639b3d6066aeR1876-R1885) [[2]](diffhunk://#diff-74fcb26047c1df4024105d36ce22a36b77cf8cc93c28631d743e639b3d6066aeL1913-R1934) * [`aten/src/ATen/native/cuda/Blas.cpp`](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abR1276-R1290): Introduced block-wise scaling for `Float8_e8m0fnu`, with checks for ROCm 7.0 and GPU architecture `gfx950`. Added validation for supported scalar types and matrix dimensions. [[1]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abR1276-R1290) [[2]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abR1349-R1364) ### Updates to scalar type mappings: * [`aten/src/ATen/cuda/CUDADataType.h`](diffhunk://#diff-9188bb13b1a49f459141f5f9b875593d1c5ce2beb5ad711fdbaf5bc7089ec015L93-R93): Extended scalar type mappings to support `Float4_e2m1fn_x2` for ROCm 7.0. * [`aten/src/ATen/cuda/tunable/GemmHipblaslt.h`](diffhunk://#diff-bfa1a3b5d4bef1892bf50338775f3b0fd8cd31fc1868148f3968b98aefb68e3fR88-R96): Added a constexpr mapping for `Float4_e2m1fn_x2` based on ROCm version. ### Enhancements to testing(@jagadish-amd): * [`test/test_matmul_cuda.py`](diffhunk://#diff-3f31c52b48cfddf8f4617d809f7695b2e4a1c78656f8c4b5143a4b45d01fcf23R765-R766): Updated tests to include new scalar types (`Float4_e2m1fn_x2`) and recipes (`mxfp4`). Added logic to handle different scaling recipes and validate compatibility with ROCm and CUDA versions. [[1]](diffhunk://#diff-3f31c52b48cfddf8f4617d809f7695b2e4a1c78656f8c4b5143a4b45d01fcf23R765-R766) [[2]](diffhunk://#diff-3f31c52b48cfddf8f4617d809f7695b2e4a1c78656f8c4b5143a4b45d01fcf23L1331-R1356) F592e669L1353R1472) These changes improve compatibility with newer hardware and software versions, enhance functionality for matrix operations, and ensure robust testing for the added features. Pull Request resolved: https://github.com/pytorch/pytorch/pull/151360 Approved by: https://github.com/drisspg, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
f2be3dc8da
commit
e389a08dcd
@ -1847,8 +1847,12 @@ int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fa
|
||||
switch (scaling_type) {
|
||||
case ScalingType::BlockWise1x32:
|
||||
TORCH_CHECK(scale_dtype == kFloat8_e8m0fnu);
|
||||
#if CUDA_VERSION >= 12080
|
||||
#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && ROCM_VERSION >= 70000)
|
||||
#ifdef USE_ROCM
|
||||
return HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
|
||||
#else
|
||||
return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
|
||||
#endif // USE_ROCM
|
||||
#else
|
||||
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales of 1x32 blocks is only supported for CUDA 12.8 and above");
|
||||
#endif // if CUDA_VERSION >= 12080
|
||||
@ -1946,12 +1950,26 @@ void scaled_gemm(
|
||||
// hipblaslt supported row-wise before cublas, and did so their own way (via
|
||||
// the SCALE_POINTERSs), but then migrated to match how cublas does it (via
|
||||
// the SCALE_MODEs). Here we check for this early custom mode.
|
||||
bool use_rowwise = (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise);
|
||||
#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
|
||||
if (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise) {
|
||||
if (use_rowwise) {
|
||||
matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
|
||||
matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
|
||||
}
|
||||
#endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
|
||||
else if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
|
||||
#if ROCM_VERSION >= 70000
|
||||
if (at::detail::getCUDAHooks().isGPUArch({"gfx950"})) {
|
||||
// TODO: add constraints based on hipblaslt internals
|
||||
TORCH_CHECK((m % 32 == 0) && (n % 32 == 0) && (k % 32 == 0),
|
||||
"Matrix dimensions must be multiples of 32 for MX format. "
|
||||
"Got m=", m, ", n=", n, ", k=", k);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
// rowwise isn't supported using cublaslt or older hipblaslt
|
||||
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
|
||||
#endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
|
||||
computeDesc.setAttribute(matmulDescA, mat1_scale_ptr);
|
||||
computeDesc.setAttribute(matmulDescB, mat2_scale_ptr);
|
||||
if (result_scale_ptr != nullptr) {
|
||||
@ -1990,15 +2008,16 @@ void scaled_gemm(
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
|
||||
}
|
||||
|
||||
// The SCALE_MODE attrs only exist in cuBLAS 12.8+ or in recent hipblaslt,
|
||||
// but we must invoke get_scale_mode anyways to trigger the version checks.
|
||||
[[maybe_unused]] int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum);
|
||||
[[maybe_unused]] int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum);
|
||||
#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode);
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode);
|
||||
#endif
|
||||
// For other data types, use the get_scale_mode function based on scaling type
|
||||
// The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt,
|
||||
// but we must invoke get_scale_mode anyways to trigger the version checks.
|
||||
// Note that AMD/ROCm follows OCP Spec 1.0, which is different from NVIDIA's implementation. See get_scale_mode() for details.
|
||||
[[maybe_unused]] int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum);
|
||||
[[maybe_unused]] int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum);
|
||||
#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && ROCM_VERSION >= 70000 && defined(HIPBLASLT_OUTER_VEC))
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode);
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode);
|
||||
#endif // if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && ROCM_VERSION >= 70000 && defined(HIPBLASLT_OUTER_VEC))
|
||||
|
||||
CuBlasLtMatmulPreference preference;
|
||||
auto ltworkspace = CublasLtWorkspace();
|
||||
|
@ -90,7 +90,7 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
|
||||
case c10::ScalarType::Float8_e5m2fnuz:
|
||||
return HIP_R_8F_E5M2_FNUZ;
|
||||
#endif
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080) || (defined(USE_ROCM) && ROCM_VERSION >= 70000)
|
||||
case c10::ScalarType::Float4_e2m1fn_x2:
|
||||
return CUDA_R_4F_E2M1;
|
||||
#endif
|
||||
|
@ -85,6 +85,15 @@ constexpr hipDataType HipDataTypeFor<c10::Float8_e8m0fnu>() {
|
||||
return static_cast<hipDataType>(500);
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr hipDataType HipDataTypeFor<c10::Float4_e2m1fn_x2>() {
|
||||
#if ROCM_VERSION >= 70000
|
||||
return HIP_R_4F_E2M1;
|
||||
#else
|
||||
return static_cast<hipDataType>(33);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int GetBatchFromParams(const GemmParams<T>* params) {
|
||||
return 1;
|
||||
|
@ -1283,15 +1283,35 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
if (use_fast_accum) {
|
||||
TORCH_CHECK(mat1.scalar_type() != ScalarType::Float4_e2m1fn_x2 && mat2.scalar_type() != ScalarType::Float4_e2m1fn_x2, "`use_fast_accum` is not supported when `mat1` or `mat2` tensors have the `Float4_e2m1fn_x2` dtype.");
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
if (mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2 || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2) {
|
||||
TORCH_CHECK(ROCM_VERSION >= 70000, "Float4_e2m1fn_x2 is only supported for ROCm 7.0 and above");
|
||||
}
|
||||
if (mat1.scalar_type() == ScalarType::Float8_e5m2 || mat2.scalar_type() == ScalarType::Float8_e5m2) {
|
||||
TORCH_CHECK(ROCM_VERSION >= 60500, "Float8_e5m2 is only supported for ROCm 6.5 and above");
|
||||
}
|
||||
if (mat1.scalar_type() == ScalarType::Float8_e4m3fn || mat2.scalar_type() == ScalarType::Float8_e4m3fn) {
|
||||
TORCH_CHECK(ROCM_VERSION >= 60500, "Float8_e4m3fn is only supported for ROCm 6.5 and above");
|
||||
}
|
||||
#endif
|
||||
if (bias) {
|
||||
TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
|
||||
TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
|
||||
"Bias must be either Half or BFloat16, but got ", bias->scalar_type());
|
||||
TORCH_CHECK((out.scalar_type() != kFloat && out.scalar_type() != ScalarType::BFloat16) ||
|
||||
bias->scalar_type() == ScalarType::BFloat16,
|
||||
"Bias must be BFloat16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
|
||||
TORCH_CHECK(out.scalar_type() != ScalarType::Half || bias->scalar_type() == ScalarType::Half,
|
||||
"Bias must be Float16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
|
||||
TORCH_CHECK(out.scalar_type() != kFloat,
|
||||
"Bias is not supported when out_dtype is set to Float32");
|
||||
|
||||
TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 ||
|
||||
bias->scalar_type() == ScalarType::Half,
|
||||
"Bias must be BFloat16 or Half, but got ", bias->scalar_type());
|
||||
|
||||
TORCH_CHECK((out.scalar_type() != kFloat &&
|
||||
out.scalar_type() != ScalarType::BFloat16) ||
|
||||
bias->scalar_type() == ScalarType::BFloat16,
|
||||
"Bias must be BFloat16 to compute ", out.scalar_type(),
|
||||
" output, but got ", bias->scalar_type());
|
||||
|
||||
TORCH_CHECK(out.scalar_type() != ScalarType::Half ||
|
||||
bias->scalar_type() == ScalarType::Half,
|
||||
"Bias must be Float16 to compute ", out.scalar_type(),
|
||||
" output, but got ", bias->scalar_type());
|
||||
}
|
||||
{
|
||||
auto bias_ = bias.value_or(Tensor());
|
||||
@ -1353,6 +1373,22 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16,
|
||||
"hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
|
||||
}
|
||||
else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) {
|
||||
#if ROCM_VERSION >= 70000
|
||||
TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
|
||||
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
|
||||
|
||||
TORCH_CHECK(mat1.size(0) % 32 == 0 && mat1.size(1) % 32 == 0 &&
|
||||
mat2.size(0) % 32 == 0 && mat2.size(1) % 32 == 0,
|
||||
"Matrix dimensions must be multiples of 32 for block-wise scaling");
|
||||
|
||||
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
|
||||
out.scalar_type() == ScalarType::Half,
|
||||
"Block-wise scaling only supports BFloat16 or Half output types");
|
||||
#else
|
||||
TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b);
|
||||
@ -1430,12 +1466,14 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
params.k = args.k;
|
||||
params.a = args.mata->data_ptr();
|
||||
params.a_scale_ptr = args.scale_mata_ptr;
|
||||
params.a_scale_dtype = args.scale_mata_dtype.value();
|
||||
params.lda = args.lda;
|
||||
params.a_dtype = args.mata->scalar_type();
|
||||
params.a_scale_dtype = args.scale_mata_dtype.value();
|
||||
params.a_scaling_type = args.scaling_mata_type.value();
|
||||
params.b = args.matb->data_ptr();
|
||||
params.b_scale_ptr = args.scale_matb_ptr;
|
||||
params.b_scale_dtype = args.scale_matb_dtype.value();
|
||||
params.ldb = args.ldb;
|
||||
params.b_dtype = args.matb->scalar_type();
|
||||
params.b_scale_dtype = args.scale_matb_dtype.value();
|
||||
|
@ -66,6 +66,7 @@ from torch.testing._internal.common_utils import (
|
||||
parametrize,
|
||||
serialTest,
|
||||
skipIfHpu,
|
||||
skipIfRocm,
|
||||
skipIfWindows,
|
||||
TEST_WITH_ROCM,
|
||||
)
|
||||
@ -7405,6 +7406,7 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
|
||||
out = f_compiled(x, s0, s1, s2)
|
||||
self.assertEqual(out_ref, out)
|
||||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "requires gpu with fp8 support")
|
||||
@requires_cuda
|
||||
def test_partitioner_saves_weights_for_bw(self):
|
||||
|
@ -918,6 +918,8 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
# largest power of 2 representable in `torch.float8_e4m3fn`
|
||||
F8E4M3_LARGEST_POW2 = 8
|
||||
# largest power of 2 representable in `torch.float4_e2m1fn_x2`
|
||||
FP4E2M1FN_LARGEST_POW2 = 1.0
|
||||
# max value of `torch.float8_e4m3fn` (448)
|
||||
F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max
|
||||
# exponent bias of `torch.float8_e8m0fnu`
|
||||
@ -926,14 +928,20 @@ F8E8M0_EXP_BIAS = 127
|
||||
FP4_EBITS, FP4_MBITS = 2, 1
|
||||
FP4_MAX_VAL = 6.0
|
||||
|
||||
def data_to_mx_scale(x, block_size):
|
||||
def data_to_mx_scale(x, block_size, recipe):
|
||||
# simple implementation of https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
|
||||
# section 6.3, not all edge cases (such as NaN) are handled/tested
|
||||
if recipe == "mxfp8":
|
||||
largest_pow2 = F8E4M3_LARGEST_POW2
|
||||
elif recipe == "mxfp4":
|
||||
largest_pow2 = FP4E2M1FN_LARGEST_POW2
|
||||
else:
|
||||
raise ValueError(f"data_to_mx_scale(): Unsupported mx recipe: {recipe}")
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(-1, block_size)
|
||||
max_abs = torch.amax(torch.abs(x), 1)
|
||||
largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs))
|
||||
scale_e8m0_unbiased = largest_p2_lt_max_abs - F8E4M3_LARGEST_POW2
|
||||
scale_e8m0_unbiased = largest_p2_lt_max_abs - largest_pow2
|
||||
scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, -1 * F8E8M0_EXP_BIAS, F8E8M0_EXP_BIAS)
|
||||
scale_e8m0_biased = scale_e8m0_unbiased + F8E8M0_EXP_BIAS
|
||||
scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)
|
||||
@ -1415,6 +1423,7 @@ class TestFP8Matmul(TestCase):
|
||||
self.assertGreaterEqual(float(cosine_sim), 0.999)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
@unittest.skipIf(torch.version.hip is not None, "Float8_e4m3fn not supported on current ROCm CI setup (MI325X)")
|
||||
@parametrize("which_dim_zero", [0, 1, 2])
|
||||
@parametrize("use_torch_compile", [False, True])
|
||||
def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
|
||||
@ -1553,23 +1562,24 @@ class TestFP8Matmul(TestCase):
|
||||
(127, 96, 1024),
|
||||
(1025, 128, 96)
|
||||
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
|
||||
@parametrize("recipe", ["mxfp8", "nvfp4"])
|
||||
def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
|
||||
if recipe == "nvfp4" and fast_accum:
|
||||
return unittest.skip("fast_accum not supported in nvfp4 cublas gemm, skipping")
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
|
||||
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
|
||||
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
|
||||
return unittest.skip("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
|
||||
|
||||
device = "cuda"
|
||||
M, K, N = mkn
|
||||
if recipe == "nvfp4" and K % 32 != 0:
|
||||
return unittest.skip("K must be divisible by 32 for nvfp4 cublas gemm, skipping")
|
||||
if (recipe == "nvfp4" or recipe == "mxfp4") and K % 32 != 0:
|
||||
return unittest.skip("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping")
|
||||
|
||||
BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
|
||||
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
|
||||
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
|
||||
require_exact_match = True
|
||||
approx_match_sqnr_target = 22.0
|
||||
|
||||
if test_case_name == "a_eye_b_eye":
|
||||
if not ((M == K) and (M == N)):
|
||||
return unittest.skip("this test is only defined for M == K == N, skipping")
|
||||
raise unittest.SkipTest("this test is only defined for M == K == N, skipping")
|
||||
A_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
|
||||
B_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
|
||||
|
||||
@ -1578,11 +1588,11 @@ class TestFP8Matmul(TestCase):
|
||||
B = B_ref.to(torch.float8_e4m3fn)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
||||
else: # nvfp4
|
||||
else: # nvfp4 # mxfp4
|
||||
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
|
||||
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
|
||||
|
||||
elif test_case_name == "a_ones_b_ones":
|
||||
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
|
||||
@ -1593,11 +1603,11 @@ class TestFP8Matmul(TestCase):
|
||||
B = B_ref.to(torch.float8_e4m3fn)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
||||
else: # nvfp4
|
||||
else: # nvfp4 # mxfp4
|
||||
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
|
||||
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
|
||||
|
||||
elif test_case_name == "a_ones_modified_b_ones":
|
||||
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
|
||||
@ -1609,11 +1619,11 @@ class TestFP8Matmul(TestCase):
|
||||
B = B_ref.to(torch.float8_e4m3fn)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
||||
else: # nvfp4
|
||||
else: # nvfp4 # mxfp4
|
||||
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
|
||||
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
|
||||
|
||||
elif test_case_name == "a_ones_b_ones_modified":
|
||||
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
|
||||
@ -1625,11 +1635,11 @@ class TestFP8Matmul(TestCase):
|
||||
B = B_ref.to(torch.float8_e4m3fn)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
||||
else: # nvfp4
|
||||
else: # nvfp4 # mxfp4
|
||||
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
|
||||
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
|
||||
|
||||
elif test_case_name == "a_scale_modified_b_ones":
|
||||
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
|
||||
@ -1643,11 +1653,11 @@ class TestFP8Matmul(TestCase):
|
||||
A_ref[1][0:BLOCK_SIZE] = 4
|
||||
A[1][0:BLOCK_SIZE] = 2
|
||||
A_scale[1][0] = 2
|
||||
else: # nvfp4
|
||||
else: # nvfp4 # mxfp4
|
||||
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
|
||||
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
|
||||
A_ref[1][0:BLOCK_SIZE] = 4
|
||||
A.view(torch.uint8)[1][0:(BLOCK_SIZE // 2)] = 0b01000100
|
||||
A_scale[1][0] = 2
|
||||
@ -1664,11 +1674,11 @@ class TestFP8Matmul(TestCase):
|
||||
B_ref[1][0:BLOCK_SIZE] = 4
|
||||
B[1][0:BLOCK_SIZE] = 2
|
||||
B_scale[1][0] = 2
|
||||
else: # nvfp4
|
||||
else: # nvfp4 # mxfp4
|
||||
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
|
||||
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
|
||||
B_ref[1][0:BLOCK_SIZE] = 4
|
||||
B.view(torch.uint8)[1][0:(BLOCK_SIZE // 2)] = 0b01000100
|
||||
B_scale[1][0] = 2
|
||||
@ -1688,7 +1698,7 @@ class TestFP8Matmul(TestCase):
|
||||
B = B_ref.to(torch.float8_e4m3fn)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
||||
else: # nvfp4
|
||||
else: # nvfp4 # mxfp4
|
||||
# scales all-ones, element data random while being exactly representable in float4_e2m1fn_x2
|
||||
# generate integers in [0, 16] and cast to bfloat16
|
||||
A_ref = _floatx_unpacked_to_f32(
|
||||
@ -1703,8 +1713,8 @@ class TestFP8Matmul(TestCase):
|
||||
).bfloat16()
|
||||
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
|
||||
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
|
||||
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
|
||||
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
|
||||
|
||||
elif test_case_name == "data_random_scales_from_data":
|
||||
if not K % BLOCK_SIZE == 0:
|
||||
@ -1716,17 +1726,18 @@ class TestFP8Matmul(TestCase):
|
||||
|
||||
if recipe == "mxfp8":
|
||||
# Calculate scales based on the inputs
|
||||
A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE)
|
||||
B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE)
|
||||
A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE, recipe)
|
||||
B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE, recipe)
|
||||
max_val = F8E4M3_MAX_VAL
|
||||
min_val = -1 * max_val
|
||||
A = (A_ref.reshape(-1, BLOCK_SIZE) / A_scale.reshape(M * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(M, K)
|
||||
A = A.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
|
||||
B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
|
||||
B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
|
||||
else: # nvfp4
|
||||
A_scale = data_to_nvfp4_scale(A_ref, BLOCK_SIZE)
|
||||
B_scale = data_to_nvfp4_scale(B_ref, BLOCK_SIZE)
|
||||
else: # nvfp4 # mxfp4
|
||||
scale_func = data_to_mx_scale if recipe == "mxfp4" else data_to_nvfp4_scale
|
||||
A_scale = scale_func(A_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
|
||||
B_scale = scale_func(B_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
|
||||
max_val = FP4_MAX_VAL
|
||||
min_val = -1 * max_val
|
||||
|
||||
@ -1737,13 +1748,14 @@ class TestFP8Matmul(TestCase):
|
||||
B = B.clamp(min=min_val, max=max_val)
|
||||
B = _bfloat16_to_float4_e2m1fn_x2(B)
|
||||
|
||||
approx_match_sqnr_target = 15.8
|
||||
approx_match_sqnr_target = 12.0 if torch.version.hip else 15.8
|
||||
|
||||
C_ref = A_ref @ B_ref.t()
|
||||
|
||||
# convert to swizzled format
|
||||
A_scale = to_blocked(A_scale)
|
||||
B_scale = to_blocked(B_scale)
|
||||
if not torch.version.hip:
|
||||
A_scale = to_blocked(A_scale)
|
||||
B_scale = to_blocked(B_scale)
|
||||
|
||||
C = torch._scaled_mm(
|
||||
A,
|
||||
|
@ -120,12 +120,20 @@ def evaluate_platform_supports_fp8_grouped_gemm():
|
||||
return SM90OrLater and not SM100OrLater
|
||||
return False
|
||||
|
||||
def evaluate_platform_supports_mx_gemm():
|
||||
if torch.cuda.is_available():
|
||||
if torch.version.hip:
|
||||
ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])
|
||||
if ROCM_VERSION >= (7, 0):
|
||||
return 'gfx950' in torch.cuda.get_device_properties(0).gcnArchName
|
||||
else:
|
||||
return SM100OrLater
|
||||
return False
|
||||
|
||||
PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mx_gemm())
|
||||
PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
|
||||
|
||||
PLATFORM_SUPPORTS_FP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_fp8_grouped_gemm())
|
||||
|
||||
PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater)
|
||||
|
||||
if TEST_NUMBA:
|
||||
try:
|
||||
import numba.cuda
|
||||
|
@ -3999,6 +3999,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
|
||||
("CUDA_C_64U", ("HIP_C_64U", CONV_TYPE, API_RUNTIME)),
|
||||
("CUDA_R_8F_E4M3", ("HIP_R_8F_E4M3", CONV_TYPE, API_RUNTIME)),
|
||||
("CUDA_R_8F_E5M2", ("HIP_R_8F_E5M2", CONV_TYPE, API_RUNTIME)),
|
||||
("CUDA_R_4F_E2M1", ("HIP_R_4F_E2M1", CONV_TYPE, API_RUNTIME)),
|
||||
(
|
||||
"MAJOR_VERSION",
|
||||
("hipLibraryMajorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED),
|
||||
@ -7693,6 +7694,10 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
|
||||
("CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", ("HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", CONV_MATH_FUNC, API_BLAS)),
|
||||
("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)),
|
||||
("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)),
|
||||
("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
|
||||
("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
|
||||
("CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", CONV_MATH_FUNC, API_BLAS)),
|
||||
("CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3", CONV_MATH_FUNC, API_BLAS)),
|
||||
("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)),
|
||||
("cublasLtMatrixLayoutOpaque_t", ("hipblasLtMatrixLayoutOpaque_t", CONV_MATH_FUNC, API_BLAS)),
|
||||
("cublasLtMatrixLayoutAttribute_t", ("hipblasLtMatrixLayoutAttribute_t", CONV_MATH_FUNC, API_BLAS)),
|
||||
|
Reference in New Issue
Block a user