diff --git a/CMakeLists.txt b/CMakeLists.txt index 05f14edcf3a6..4120e621bdd0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -889,6 +889,12 @@ IF(USE_FBGEMM_GENAI AND USE_ROCM AND NOT "gfx942" IN_LIST PYTORCH_ROCM_ARCH) set(USE_FBGEMM_GENAI off) endif() +# Set USE_FBGEMM_GENAI to ON for CUDA build on SM100 +if(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0a") + message(WARNING "Setting USE_FBGEMM_GENAI to ON for CUDA build on SM100") + set(USE_FBGEMM_GENAI ON) +endif() + # CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem # Eff Attention won't cmake_dependent_option( diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 6f7482dfd066..6c095680733f 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -255,48 +255,77 @@ endif() # FBGEMM GenAI IF(USE_FBGEMM_GENAI) set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/) - set(FBGEMM_GENAI_DIR ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize) + set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize) + if(USE_CUDA) + # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build. + # If you want to integrate a kernel from FBGEMM into torch, you have to add it here. + set(FBGEMM_CUTLASS_KERNELS_REGEX ".*mx8mx8bf16_grouped.*") + file(GLOB_RECURSE fbgemm_genai_native_cuda_cu + "${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu" + "${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu") + list(FILTER fbgemm_genai_native_cuda_cu INCLUDE REGEX ${FBGEMM_CUTLASS_KERNELS_REGEX}) - if(USE_ROCM) - # Only include the kernels we want to build to avoid increasing binary size. - file(GLOB_RECURSE fbgemm_genai_native_rocm_hip - "${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip" - "${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip") - set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + file(GLOB_RECURSE fbgemm_genai_native_cuda_cpp + "${FBGEMM_GENAI_SRCS}/common/*.cpp" + ) - # Add additional HIPCC compiler flags for performance - set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS - -mllvm - -amdgpu-coerce-illegal-types=1 - -mllvm - -enable-post-misched=0 - -mllvm - -greedy-reverse-local-assignment=1 - -fhip-new-launch-api) + # Combine all source files into a single list + list(APPEND fbgemm_genai_all_sources + ${fbgemm_genai_native_cuda_cu} + ${fbgemm_genai_native_cuda_cpp} + ) - # Only compile for gfx942 for now. 
- # This is rather hacky, I could not figure out a clean solution :( - set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS}) - string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}") - list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;) - set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS}) - - hip_add_library( - fbgemm_genai STATIC - ${fbgemm_genai_native_rocm_hip} - HIPCC_OPTIONS ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS}) - set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL}) + # Now, create the library and provide the sources at the same time + add_library(fbgemm_genai OBJECT ${fbgemm_genai_all_sources}) set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON) - target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES) + + set(fbgemm_genai_mx8mx8bf16_grouped + "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/" + ) target_include_directories(fbgemm_genai PUBLIC - # FBGEMM version of Composable Kernel is used due to some customizations - ${FBGEMM_THIRD_PARTY}/composable_kernel/include - ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include - ${FBGEMM_GENAI_DIR}/include/ - ${FBGEMM_GENAI_DIR}/common/include/ + ${FBGEMM_THIRD_PARTY}/cutlass/include + ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include + ${fbgemm_genai_mx8mx8bf16_grouped} + ${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp + ${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h ) + else() + if(USE_ROCM) + # Only include the kernels we want to build to avoid increasing binary size. + file(GLOB_RECURSE fbgemm_genai_native_rocm_hip + "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip" + "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip") + set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + + # Add additional HIPCC compiler flags for performance + set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS + -mllvm + -amdgpu-coerce-illegal-types=1 + -mllvm + -enable-post-misched=0 + -mllvm + -greedy-reverse-local-assignment=1 + -fhip-new-launch-api) + + hip_add_library( + fbgemm_genai STATIC + ${fbgemm_genai_native_rocm_hip} + HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS}) + set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES) + + target_include_directories(fbgemm_genai PUBLIC + # FBGEMM version of Composable Kernel is used due to some customizations + ${FBGEMM_THIRD_PARTY}/composable_kernel/include + ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include + ${FBGEMM_THIRD_PARTY}/cutlass/include + ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include + ${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp + ${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h + ) + endif() endif() endif() @@ -639,6 +668,13 @@ if(USE_CUDA AND NOT USE_ROCM) add_definitions(-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include) + + # Add FBGEMM_GENAI include directories for torch_ops.h + if(USE_FBGEMM_GENAI) + list(APPEND ATen_CUDA_INCLUDE 
${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include) + list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include) + endif() + if($ENV{ATEN_STATIC_CUDA}) if(CUDA_VERSION VERSION_LESS_EQUAL 12.9) list(APPEND ATen_CUDA_DEPENDENCY_LIBS diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index e5c89df516a2..23447c7e09b3 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -1551,7 +1551,8 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, } namespace { - void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) { + void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) { + // Checks scales for 2d or 3d target tensors (`mat`). if (mat.dim() == 2) { TORCH_CHECK( scale.dim() == 1, @@ -1585,9 +1586,66 @@ namespace { "scale must have the same first dimension as mat for arg ", arg_idx); } -} + } + void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) { + // Checks scales for 2d or 3d target tensors (`mat`). + if (mat.dim() == 2) { + // For MXFP8, 2d tensors have variable size groups represented as subtensors, + // that are converted to blocked padded format individually, + // so we can't check the scale sizes without doing a d2h sync to get the group sizes here. + TORCH_CHECK( + scale.dim() == mat.dim(), + "for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx); + // LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4)) + // RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4)) + // * weight is transposed prior to the call, scale stays non-transposed. + bool LHS = arg_idx == 0; + int scale_dim_to_check = 0; + int mat_dim_to_check = LHS ? 0 : 1; + TORCH_CHECK( + scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check), + "for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ", + "must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")"); + } else { + // For MXFP8, 3d tensors have static group sizes (stack of 2d tensors), + // so we can check the exact expected scale sizes here without a d2h sync. + auto round_up = [](auto x, auto y) { + return ((x + y - 1) / y) * y; + }; + + // TODO: this is for 3d tensor in 2d-3d case specifically. + // We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them. + int64_t G = mat.size(0); + int64_t K = mat.size(1); + int64_t N = mat.size(2); + int64_t blocked_scale_K = round_up(K/32, 4); + int64_t blocked_scale_N = round_up(N, 128); + + // fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N). 
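+      // Illustrative sizing (assumed values, not taken from a real model): for a 3d RHS of shape
+      // (G, K, N) = (4, 4096, 8192), blocked_scale_K = round_up(4096 / 32, 4) = 128 and
+      // blocked_scale_N = round_up(8192, 128) = 8192, so the expected scale shape is
+      // (4, 128 * 8192) = (4, 1048576).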
+ TORCH_CHECK( + scale.dim() == mat.dim() - 1, + "for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx + ); + TORCH_CHECK( + scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N, + "for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx + ); + } + } + + void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) { + bool using_fp8_rowwise = scale.scalar_type() == kFloat; + bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu; + if (using_fp8_rowwise) { + _check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier); + } else if (using_mxfp8) { + _check_scales_mxfp8(mat, scale, dim, arg_idx); + } else { + TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype()); + } + } } Tensor @@ -1612,8 +1670,8 @@ const std::optional& bias, const std::optional& scale_result, std::optional out_dtype, bool use_fast_accum) { - bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/false); - TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0, or ROCm MI300+"); + bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true); + TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+"); TORCH_CHECK(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed"); TORCH_CHECK(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed"); @@ -1646,10 +1704,12 @@ bool use_fast_accum) { TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32"); } - // Both Per-Tensor and Row-wise scaling expect fp32 tensors + // FP8 per-tensor and per-row scaling expect fp32 scales. + // MXFP8 expects float8_e8m0fnu scales. TORCH_CHECK( - scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat, - "Both scale_a and scale_b must be float (fp32) tensors."); + (scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) || + (scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu), + "For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors."); const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? 
offs->size(0) : 1; check_scale(mat_a, scale_a, 0 ,0, scale_multiplier); @@ -1660,6 +1720,32 @@ bool use_fast_accum) { Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_); +#if defined(USE_FBGEMM_GENAI) && defined(USE_CUDA) && !defined(USE_ROCM) + // MXFP8 grouped GEMM dispatching + bool is_mx8mx8bf16 = ( + mat_a.scalar_type() == at::kFloat8_e4m3fn && mat_b.scalar_type() == at::kFloat8_e4m3fn && + scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu + ); + TORCH_CHECK(out_dtype == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm"); + + if (is_mx8mx8bf16) { + bool b_is_3d = mat_b.dim() == 3; + bool is_2d_2d = a_is_2d && b_is_2d; + bool is_2d_3d = a_is_2d && b_is_3d; + TORCH_CHECK(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases"); + TORCH_CHECK(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets"); + + fbgemm_gpu::mx8mx8bf16_grouped_mm( + mat_a, + mat_b, + scale_a, + scale_b, + offs.value(), + out); + return out; + } +#endif + #ifndef USE_ROCM TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type()); TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type()); @@ -1691,6 +1777,7 @@ bool use_fast_accum) { #else TORCH_CHECK(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM") #endif + #endif } diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 378cb73a225e..504dbf5a4fad 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1638,6 +1638,10 @@ if(USE_CUDA) # order of the libraries in the linker call matters here when statically # linking; libculibos and cublas must be last. target_link_libraries(torch_cuda PUBLIC torch_cpu_library ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS}) + if(USE_FBGEMM_GENAI) + # Link fbgemm_genai to torch_cuda (only for (1) CUDA build for SM100). + target_link_libraries(torch_cuda PRIVATE fbgemm_genai) + endif() endif() # ---[ XPU library. @@ -1759,9 +1763,10 @@ if(USE_ROCM) target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS}) if(USE_FBGEMM_GENAI) - target_link_libraries(torch_hip PRIVATE fbgemm_genai) + if(USE_ROCM) + target_link_libraries(torch_hip PRIVATE fbgemm_genai) + endif() endif() - # Since PyTorch files contain HIP headers, this is also needed to capture the includes. 
# ROCM_INCLUDE_DIRS is defined in LoadHIP.cmake target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE} ${ROCM_INCLUDE_DIRS}) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 6935c5e902bb..4b42637fde66 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -30,6 +30,7 @@ from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FP8, PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, PLATFORM_SUPPORTS_MX_GEMM, + PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, IS_SM90, ) from torch.testing._internal.common_device_type import ( @@ -55,7 +56,13 @@ from torch.testing._internal.common_utils import ( TEST_WITH_ROCM, TestCase, ) -from torch.testing._internal.common_quantized import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, ceil_div, to_blocked +from torch.testing._internal.common_quantized import ( + _f32_to_floatx_unpacked, + _floatx_unpacked_to_f32, + ceil_div, to_blocked, + to_mxfp8, + generate_jagged_offs, +) _IS_SM8X = False if TEST_CUDA: @@ -771,6 +778,7 @@ class TestMatmulCuda(TestCase): f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices" f8_grouped_msg = "FP8 grouped is only supported on SM90 and MI300+ devices" mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+" +mxfp8_grouped_mm_skip_msg = "MXFP8 grouped GEMM is only supported when PyTorch is built with USE_FBGEMM_GENAI=1 on SM100+" # avoid division by zero when calculating scale EPS = 1e-12 @@ -901,6 +909,8 @@ def to_fp8_saturated( return x.to(fp8_dtype) + + def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Computes the error between two tensors in dB. @@ -1045,6 +1055,167 @@ class TestFP8Matmul(TestCase): out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b) self.assertEqual(out_fp8, out_fp8_s) + @unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg) + @parametrize("G", [1, 4, 16]) + @parametrize("M", [2048, 2049]) + @parametrize("N", [8192]) + @parametrize("K", [16640]) + def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K): + torch.manual_seed(42) + total_K = K # Alias for clarity, communicating this consists of several groups along this dim + input_group_end_offsets = generate_jagged_offs( + G, total_K, multiple_of=32, device="cuda" + ) + X = torch.randn((M, total_K), dtype=torch.bfloat16, device="cuda") * 0.1 + W = torch.randn((N, total_K), dtype=torch.bfloat16, device="cuda") * 0.01 + + # Convert scales to blocked format. + x_list = [] + w_list = [] + x_blocked_scale_list = [] + w_blocked_scale_list = [] + + def round_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + for group_idx in range(G): + # to_mxfp8 per group + prev_group_end_offset = ( + 0 if group_idx == 0 else input_group_end_offsets[group_idx - 1] + ) + curr_group_end_offset = input_group_end_offsets[group_idx] + group_size = curr_group_end_offset - prev_group_end_offset + if group_size > 0: + x_slice = X[ + :, prev_group_end_offset:curr_group_end_offset + ].contiguous() # (M, K_group) + w_slice = W[ + :, prev_group_end_offset:curr_group_end_offset + ].contiguous() # (N, K_group) + x_scale_slice, xq_slice = to_mxfp8( + x_slice + ) # scale shape -> (M, K_group // 32) + w_scale_slice, wq_slice = to_mxfp8( + w_slice + ) # scale shape -> (N, K_group // 32) + x_list.append(xq_slice) + w_list.append(wq_slice) + + # Convert scales to blocked format. 
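+                # to_blocked (from common_quantized.py) pads each per-group scale to
+                # (round_up(rows, 128), round_up(cols, 4)) and rearranges it into the blocked
+                # layout the kernel expects, returning the result flattened.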
+ x_scale_slice_blocked = to_blocked( + x_scale_slice + ) # (round_up(M, 128), round_up(K_group//32, 4)) + w_scale_slice_blocked = to_blocked( + w_scale_slice + ) # (round_up(N, 128), round_up(K_group//32, 4)) + x_blocked_scale_list.append(x_scale_slice_blocked) + w_blocked_scale_list.append(w_scale_slice_blocked) + + # Assemble the full XQ and WQ + xq = torch.cat(x_list, dim=1).contiguous() + wq = torch.cat(w_list, dim=1).contiguous() + + # Combine all XQ groups blocked scales into one tensor. + x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0) + M_rounded = round_up(M, 128) + x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1) + + # Combine all WQ groups blocked scales into one tensor. + w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0) + N_rounded = round_up(N, 128) + w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1) + + # Compute mxfp8 grouped mm output + y_mxfp8 = torch._scaled_grouped_mm( + xq, # (M, total_K) + wq.transpose(-2, -1), # (total_K, N) + x_blocked_scales, # to_blocked_per_group(M, total_K//32) + w_blocked_scales, # to_blocked_per_group(N, total_K//32) + offs=input_group_end_offsets, # (G,) + out_dtype=torch.bfloat16, + ) + + # bf16 reference output + y_bf16 = torch._grouped_mm( + X, W.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16 + ) + + # Assert no NaNs + assert not y_mxfp8.isnan().any(), "mxfp8 output contains NaN" + + # Assert outputs are close + torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2) + + @unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg) + @parametrize("G", [1, 4, 16]) + @parametrize("M", [16640]) + @parametrize("N", [8192]) + @parametrize("K", [4096]) + def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K): + torch.manual_seed(42) + # Simulate 2d-3d grouped gemm `out = input @ weight.t()` + # 2D inputs with groups along M, 3D weights. + block_size = 32 + total_M = M # Alias for clarity that M dim contains groups. + X = torch.randn((total_M, K), dtype=torch.bfloat16, device="cuda") * 0.1 + W = torch.randn((G, N, K), dtype=torch.bfloat16, device="cuda") * 0.01 + input_group_end_offsets = generate_jagged_offs( + G, total_M, multiple_of=32, device="cuda" + ) + + # For each constituent 2d subtensor in the 3d weights, quantize and convert scale to blocked format separately, + # as they each used for independent gemm in the grouped gemm. + wq_list = [] + w_scale_list = [] + for i in range(G): + w_scale, wq = to_mxfp8(W[i]) + w_scale = to_blocked(w_scale) + wq_list.append(wq) + w_scale_list.append(w_scale) + wq = torch.stack(wq_list, dim=0).contiguous() + w_scale = torch.stack(w_scale_list, dim=0).contiguous() + + # For each group along `total_M` in the 2D tensor, quantize and convert scale to blocked format separately, + # as they each used for independent gemm in the grouped gemm. + xq_list = [] + x_scale_list = [] + for i in range(G): + prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1] + curr_group_end = input_group_end_offsets[i] + group_size = curr_group_end - prev_group_end + if group_size > 0: + x_slice = X[prev_group_end:curr_group_end, :] + x_scale, xq = to_mxfp8(x_slice) + x_scale = to_blocked(x_scale) + xq_list.append(xq) + x_scale_list.append(x_scale) + xq = torch.cat(xq_list, dim=0).contiguous() + x_scale = torch.cat(x_scale_list, dim=0).contiguous() + x_scale = x_scale.reshape(-1, K // block_size) + xq = xq.view(-1, xq.shape[-1]) + + # Compute mxfp8 grouped gemm. 
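+        # Shapes at this point, for reference: xq is (total_M, K) float8_e4m3fn,
+        # wq.transpose(-2, -1) is (G, K, N) float8_e4m3fn, x_scale is a 2d float8_e8m0fnu
+        # blocked scale, and w_scale is (G, blocked_size) float8_e8m0fnu, one blocked scale per group.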
+ y_mxfp8 = torch._scaled_grouped_mm( + xq, + wq.transpose(-2, -1), + x_scale, + w_scale, + offs=input_group_end_offsets, + out_dtype=torch.bfloat16, + ) + + # Compute reference bf16 grouped gemm. + y_bf16 = torch._grouped_mm( + X, + W.transpose(-2, -1), + offs=input_group_end_offsets, + out_dtype=torch.bfloat16, + ) + + # Assert outputs are close. + torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32]) def test_scaled_mm_vs_emulated(self, base_dtype): diff --git a/third_party/fbgemm b/third_party/fbgemm index 21c7d30c526c..4b39c551efe1 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 21c7d30c526c0f1ad873ecc632dca6cfa8a69067 +Subproject commit 4b39c551efe15e6bbade20565b0ceb2d8ce3352d diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index d1c3b42d9fa8..7a0301371b11 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -7424,17 +7424,17 @@ def _meta_grouped_mm_common( fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn torch._check( mat_a.dtype == fp8_dtype and mat_b.dtype == fp8_dtype, - lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", + lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", # noqa: B950 ) else: torch._check( mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16, - lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", + lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", # noqa: B950 ) torch._check( mat_a.dim() in [2, 3] and mat_b.dim() in [2, 3], - lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}", + lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}", # noqa: B950 ) mat_a_is_2d = mat_a.dim() == 2 @@ -7458,11 +7458,11 @@ def _meta_grouped_mm_common( torch._check( is_row_major(mat_a), - lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}", + lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}", # noqa: B950 ) torch._check( is_col_major(mat_b), - lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}", + lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}", # noqa: B950 ) def check_valid_strides(mat_name, mat): @@ -7474,7 +7474,7 @@ def _meta_grouped_mm_common( ): torch._check( mat_stride[end_dim] % alignment == 0, - lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.", + lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.", # noqa: B950 ) elif mat_stride[end_dim] == 1 and mat_stride[end_dim - 1] >= max( 1, mat.shape[end_dim] @@ -7494,41 +7494,81 @@ def _meta_grouped_mm_common( if scale_a is not None and scale_b is not None: torch._check( - scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32, - lambda: "Both scale_a and scale_b must be float (fp32) tensors, but got scale_a.dtype={scale_a.dtype} and 
scale_b.dtype={scale_b.dtype}.", # noqa: B950 + (scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32) + or ( + scale_a.dtype == torch.float8_e8m0fnu + and scale_b.dtype == torch.float8_e8m0fnu + ), + lambda: f"For FP8 scales must both be float32, or for MXFP8 both scales must be float8_e8m0fnu. Got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.", # noqa: B950 ) + is_mxfp8 = ( + scale_a.dtype == torch.float8_e8m0fnu + and scale_b.dtype == torch.float8_e8m0fnu + ) + + def round_up(x, y): + """Rounds up x to nearest multiple of y""" + return ((x + y - 1) // y) * y def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): if mat.dim() == 2: - torch._check( - scale.dim() == 1, - lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.", - ) torch._check( scale.is_contiguous(), lambda: f"Expected {scale_name} to be contiguous.", ) - torch._check( - scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier, - lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.", # noqa: B950 - ) + # For MXFP8, 2d tensors have variable size groups represented as subtensors, + # that are converted to blocked padded format individually. At compile time we don't know + # the group sizes yet, so we don't know the expect size of the blocked format scale. + # This limits what we can check here. + if is_mxfp8: + torch._check( + scale.dim() == mat.dim(), + lambda: f"For MXFP8, scale must have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}", # noqa: B950 + ) + else: + torch._check( + scale.dim() == 1, + lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.", + ) + torch._check( + scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier, + lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.", # noqa: B950 + ) else: torch._check( - scale.dim() == 2, - lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.", - ) - torch._check( - scale.stride(1) == 1, + scale.stride(-1) == 1, lambda: f"Expected {scale_name} to be contiguous in the last dimension.", ) torch._check( scale.shape[0] == mat.shape[0], lambda: f"Expected {scale_name} batch dimension to be {mat.shape[0]}, got {scale.shape[0]}.", ) - torch._check( - scale.shape[1] == mat.shape[1 + scaled_dim], - lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.", - ) + # For MXFP8, 3d tensors have static 'groups' (stack of 2d tensors) so we can know the expected blocked + # scale sizes at compile time. + if is_mxfp8: + torch._check( + mat.ndim == scale.ndim, + lambda: f"For MXFP8, scale should have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}", # noqa: B950 + ) + # TODO: This logic only holds for RHS tensor in 2d-3d case. + # We'll need to update it to handle LHS 3d tensor in 3d-2d and 3d-3d cases. 
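+                    # The sizing below mirrors the runtime check in aten/src/ATen/native/cuda/Blas.cpp:
+                    # per group, blocked scales use round_up(K / 32, 4) x round_up(N, 128) elements.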
+ G, K, N = scale.shape + block_size = 32 + blocked_K = round_up(K / block_size, 4) + blocked_N = round_up(N, 128) + torch._check( + mat.shape[-2] == blocked_K and mat.shape[-1] == blocked_N, + lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K},{blocked_N}), but got {scale.shape}", # noqa: B950 + ) + else: + torch._check( + scale.dim() == 2, + lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.", + ) + torch._check( + scale.shape[1] == mat.shape[1 + scaled_dim], + lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.", # noqa: B950 + ) scale_multiplier = ( offs.shape[0] if offs is not None and mat_a_is_2d and mat_b_is_2d else 1 diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 1616e675b32c..be284429114f 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -41,6 +41,7 @@ IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_ca IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and (torch.cuda.get_device_capability() in [(7, 2), (8, 7)] or IS_THOR)) IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9)) IS_SM90 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)) +IS_SM100 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (10, 0)) def evaluate_gfx_arch_within(arch_list): if not torch.cuda.is_available(): @@ -129,9 +130,17 @@ def evaluate_platform_supports_mx_gemm(): return SM100OrLater return False +def evaluate_platform_supports_mxfp8_grouped_gemm(): + if torch.cuda.is_available() and not torch.version.hip: + built_with_fbgemm_genai = "USE_FBGEMM_GENAI" in torch.__config__.show() + return built_with_fbgemm_genai and IS_SM100 + return False + PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mx_gemm()) PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8()) PLATFORM_SUPPORTS_FP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_fp8_grouped_gemm()) +PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater) +PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mxfp8_grouped_gemm()) if TEST_NUMBA: try: diff --git a/torch/testing/_internal/common_quantized.py b/torch/testing/_internal/common_quantized.py index 9dc177a7899b..0dc9d4cb3db7 100644 --- a/torch/testing/_internal/common_quantized.py +++ b/torch/testing/_internal/common_quantized.py @@ -479,3 +479,110 @@ def to_blocked(input_matrix) -> torch.Tensor: rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) return rearranged.flatten() + +# This function is extracted from https://github.com/pytorch/ao/blob/v0.12.0/torchao/prototype/mx_formats/mx_tensor.py#L142 +def to_mxfp8( + data_hp: torch.Tensor, + block_size: int = 32, +): + assert data_hp.dtype in ( + torch.bfloat16, + torch.float, + ), f"{data_hp.dtype} is not supported yet" + assert ( + data_hp.shape[-1] % block_size == 0 + ), f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}" + assert data_hp.is_contiguous(), "unsupported" + + orig_shape = data_hp.shape + data_hp = data_hp.reshape( + *orig_shape[:-1], orig_shape[-1] // block_size, block_size + ) + + max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1) + + data_hp = 
data_hp.to(torch.float32) + max_abs = max_abs.to(torch.float32) + + F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 + max_pos = F8E4M3_MAX + + # RCEIL + def _to_mx_rceil( + data_hp: torch.Tensor, + max_abs: torch.Tensor, + max_pos: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + E8M0_EXPONENT_BIAS = 127 + descale = max_abs / max_pos + exponent = torch.where( + torch.isnan(descale), + 0xFF, # Handle biased exponent for nan + # NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping + ( + torch.clamp( + torch.ceil(torch.log2(descale)), + min=-E8M0_EXPONENT_BIAS, + max=E8M0_EXPONENT_BIAS, + ) + + E8M0_EXPONENT_BIAS + ).to(torch.uint8), + ) + + descale_fp = torch.where( + exponent == 0, + 1.0, + torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32)), + ) + + # scale and saturated cast the data elements to max of target dtype + data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos) + return exponent, data_lp + + scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos) + + # cast to target dtype + data_lp = data_lp.to(torch.float8_e4m3fn) + # need to reshape at the end to help inductor fuse things + data_lp = data_lp.reshape(orig_shape) + + scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu) + scale_e8m0_biased = scale_e8m0_biased.squeeze(-1) + return scale_e8m0_biased, data_lp + +# Source: https://github.com/pytorch/ao/blob/568c1932a16ae9f30d48da214a88dc0013e98ed8/torchao/prototype/moe_training/utils.py#L310 +def generate_jagged_offs(E, M, multiple_of=16, dtype=torch.int32, device="cuda"): + """ + Utility function for tests and benchmarks. + + Generates a tensor of length E, containing random values divisible by `multiple_of`, + from 0 to M, in sorted order, and where the final value in the tensor is always M. + Args: + E (int): The length of the tensor. + M (int): The maximum value in the tensor. + Returns: + torch.Tensor: A tensor of length E with the specified properties. + """ + import random + + # Ensure M is divisible by 16 + if M % multiple_of != 0: + raise ValueError(f"M must be divisible by {multiple_of}") + + # Generate a list of possible values + possible_values = list(range(multiple_of, M + 1, multiple_of)) + + # If E is larger than the number of possible values, raise an error + if E > len(possible_values): + raise ValueError("E cannot be larger than the number of possible values") + + # Randomly select E - 1 values from the possible values (excluding M) + selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1)) + + # Append M to the selected values + selected_values = torch.cat((selected_values, torch.tensor([M]))) + + # Sort the selected values + selected_values, _ = torch.sort(selected_values) + + return selected_values.to(dtype).to(device)
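+
+# Illustrative usage (output values are random; this is just one possible result):
+#   >>> generate_jagged_offs(4, 1024, multiple_of=32)
+#   tensor([288, 448, 800, 1024], device='cuda:0', dtype=torch.int32)
+# i.e. E sorted offsets, each a multiple of `multiple_of`, with the final entry always equal to M.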