diff --git a/CMakeLists.txt b/CMakeLists.txt index 63a2f74404c1..119d845f7391 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -872,6 +872,14 @@ cmake_dependent_option( "USE_CUDA OR USE_ROCM;NOT MSVC" OFF) +cmake_dependent_option( + USE_FBGEMM_GENAI + "Whether to build FBGEMM GenAI quantized GEMM kernels.\ + Will be disabled if not supported by the platform" + OFF + "USE_CUDA OR USE_ROCM" + OFF) + # CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem # Eff Attention won't cmake_dependent_option( @@ -905,6 +913,10 @@ if(USE_FBGEMM) string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM") endif() +if(USE_FBGEMM_GENAI) + string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM_GENAI") +endif() + if(USE_PYTORCH_QNNPACK) string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_QNNPACK") endif() diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 4bf595740879..2938e690d491 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -247,6 +247,50 @@ if(USE_MEM_EFF_ATTENTION) list(APPEND ATen_ATTENTION_KERNEL_SRCS ${mem_eff_attention_cuda_kernels_cu}) endif() +IF(USE_FBGEMM_GENAI AND USE_ROCM AND NOT "gfx942" IN_LIST PYTORCH_ROCM_ARCH) + message(WARNING "Unsupported ROCM arch for FBGEMM GenAI, will set USE_FBGEMM_GENAI to OFF") + set(USE_FBGEMM_GENAI off) +endif() + +# FBGEMM GenAI +IF(USE_FBGEMM_GENAI) + set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/) + set(FBGEMM_GENAI_DIR ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize) + + if(USE_ROCM) + # Only include the kernels we want to build to avoid increasing binary size. + file(GLOB_RECURSE fbgemm_genai_native_rocm_hip + "${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip" + "${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip") + set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + + # Add additional HIPCC compiler flags for performance + set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS + -mllvm + -amdgpu-coerce-illegal-types=1 + -mllvm + -enable-post-misched=0 + -mllvm + -greedy-reverse-local-assignment=1 + -fhip-new-launch-api) + + hip_add_library( + fbgemm_genai STATIC + ${fbgemm_genai_native_rocm_hip} + HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS}) + set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES) + + target_include_directories(fbgemm_genai PUBLIC + # FBGEMM version of Composable Kernel is used due to some customizations + ${FBGEMM_THIRD_PARTY}/composable_kernel/include + ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include + ${FBGEMM_GENAI_DIR}/include/ + ${FBGEMM_GENAI_DIR}/common/include/ + ) + endif() +endif() + # XNNPACK file(GLOB native_xnnpack "native/xnnpack/*.cpp") diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 301f39168d0c..5317ab75ba08 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -21,6 +21,10 @@ #include #include +#ifdef USE_FBGEMM_GENAI +#include +#endif + #ifndef AT_PER_OPERATOR_HEADERS #include #include @@ -1216,7 +1220,7 @@ std::pair get_joint_scaling( // - `scale_a`: a tensor with the inverse scale of `mat1`, whose shape/strides/dtype depend on the scaling scheme // - `scale_b`: a tensor with the inverse scale of `mat2`, whose shape/strides/dtype depend on the scaling scheme // - `scale_result`: a scalar 
tensor with the scale of the output, only utilized if the output is a float8 type -// - `use_fast_accum`: if true, enables fast float8 accumulation +// - `use_fast_accum`: if true, enables fast float8 accumulation. Backends may ignore this option if not applicable. // - `out`: a reference to the output tensor Tensor& @@ -1525,6 +1529,7 @@ namespace { const auto out_dtype_ = out_dtype.value_or(kBFloat16); TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm"); + #ifndef USE_ROCM // For TMA transfers, strides of output tensor have to be either // 1, or aligned to 16 bytes. const auto last_dim = out_size.size() - 1; @@ -1536,9 +1541,10 @@ namespace { } else { out_stride = {out_size[1] * size_padded, size_padded, 1}; } - auto out = at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_)); - - return out; + return at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_)); + #else + return at::empty(out_size, mat_a.options().dtype(out_dtype_)); + #endif } bool check_valid_strides_and_return_transposed(const Tensor& mat) { @@ -1619,12 +1625,9 @@ const std::optional& bias, const std::optional& scale_result, std::optional out_dtype, bool use_fast_accum) { -#ifndef USE_ROCM - bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true); - TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0"); + bool allowed_device = _scaled_mm_allowed_device(); + TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0, or ROCm MI300+"); - TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type()); - TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type()); TORCH_CHECK(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed"); TORCH_CHECK(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed"); TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d"); @@ -1664,6 +1667,10 @@ bool use_fast_accum) { Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype); +#ifndef USE_ROCM + TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type()); + TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type()); + at::cuda::detail::f8f8bf16_grouped_mm( mat_a, mat_b, @@ -1674,12 +1681,23 @@ bool use_fast_accum) { use_fast_accum, out); return out; - - - - #else - TORCH_CHECK(false, "grouped gemm is not supported on ROCM") +#ifdef USE_FBGEMM_GENAI + TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type()); + TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_b.scalar_type()); + + fbgemm_gpu::f8f8bf16_rowwise_grouped_mm( + mat_a, + // FBGEMM expects B matrix shape to be (.., N, K) + mat_b.transpose(-2, -1), + scale_a, + scale_b, + offs, + out); + return out; +#else + TORCH_CHECK(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM") +#endif #endif } diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 776688dccad5..db10db0ea7c0 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1771,6 +1771,10 @@ if(USE_ROCM) 
target_link_libraries(torch_hip PUBLIC torch_cpu_library ${Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS}) target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS}) + if(USE_FBGEMM_GENAI) + target_link_libraries(torch_hip PRIVATE fbgemm_genai) + endif() + # Since PyTorch files contain HIP headers, this is also needed to capture the includes. # ROCM_INCLUDE_DIRS is defined in LoadHIP.cmake target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE} ${ROCM_INCLUDE_DIRS}) diff --git a/setup.py b/setup.py index b222e674c0ca..189a78c23bbb 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,9 @@ # USE_FBGEMM=0 # disables the FBGEMM build # +# USE_FBGEMM_GENAI=1 +# enables the FBGEMM GenAI kernels to build +# # USE_KINETO=0 # disables usage of libkineto library for profiling # diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 6ee099f47d40..96a51c33386b 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -27,6 +27,7 @@ from torch.testing._internal.common_cuda import ( xfailIfSM120OrLater, _get_torch_cuda_version, PLATFORM_SUPPORTS_FP8, + PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, PLATFORM_SUPPORTS_MX_GEMM, IS_SM90, ) @@ -768,6 +769,7 @@ class TestMatmulCuda(TestCase): torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices" +f8_grouped_msg = "FP8 grouped is only supported on SM90 and MI300+ devices" mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+" # avoid division by zero when calculating scale @@ -1845,17 +1847,16 @@ class TestFP8Matmul(TestCase): # _scaled_mm() already has more combinations of parameters than # _scaled_grouped_mm(), for supporting more than one inputs layout # combinations. - - @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") - @xfailIfSM100OrLater - @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, f8_grouped_msg) @parametrize("fast_accum", [False, True]) - @parametrize("strided", [False, True]) + # AMD does not support non-contiguous inputs yet + @parametrize("strided", [False] + ([True] if torch.version.cuda else [])) def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided): device = "cuda" + fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn m, n, k, n_groups = 16, 32, 64, 4 - a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups] - b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups] + a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(fp8_dtype)[:, :k * n_groups] + b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(fp8_dtype)[:, :k * n_groups] scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32) scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32) offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32) @@ -1874,17 +1875,17 @@ class TestFP8Matmul(TestCase): self.scaled_grouped_mm_helper(alist, blist, ascalelist, bscalelist, out, fast_accum) - @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") - @xfailIfSM100OrLater - @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, f8_grouped_msg) @parametrize("fast_accum", [False, True]) - @parametrize("strided", [False, True]) + # AMD does not support non-contiguous inputs yet + 
@parametrize("strided", [False] + ([True] if torch.version.cuda else [])) def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided): device = "cuda" + fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn m, n, k, n_groups = 16, 32, 64, 4 s_int = int(strided) - a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k] - b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] + a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(fp8_dtype)[:, :k] + b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(fp8_dtype)[::(1 + s_int), :, :k] self.assertTrue(a.is_contiguous() is not strided) self.assertTrue(b.is_contiguous() is not strided) for check_zero_size in (True, False): @@ -1896,7 +1897,6 @@ class TestFP8Matmul(TestCase): offs[0] = offs[1] scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32) scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n) - f = torch._scaled_grouped_mm out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs, out_dtype=torch.bfloat16, use_fast_accum=fast_accum) @@ -1912,17 +1912,17 @@ class TestFP8Matmul(TestCase): self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum) - @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") - @xfailIfSM100OrLater - @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, f8_grouped_msg) @parametrize("fast_accum", [False, True]) - @parametrize("strided", [False, True]) + # AMD does not support non-contiguous inputs yet + @parametrize("strided", [False] + ([True] if torch.version.cuda else [])) def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided): device = "cuda" + fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn m, n, k, n_groups = 16, 32, 64, 4 s_int = int(strided) - a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] - b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] + a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(fp8_dtype)[::(1 + s_int), :, :k] + b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(fp8_dtype)[::(1 + s_int), :, :k] self.assertTrue(a.is_contiguous() is not strided) self.assertTrue(b.is_contiguous() is not strided) scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m) @@ -1935,17 +1935,17 @@ class TestFP8Matmul(TestCase): self.scaled_grouped_mm_helper(a, b, scale_a, scale_b, out, fast_accum) - @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") - @xfailIfSM100OrLater - @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, f8_grouped_msg) @parametrize("fast_accum", [False, True]) - @parametrize("strided", [False, True]) + # AMD does not support non-contiguous inputs yet + @parametrize("strided", [False] + ([True] if torch.version.cuda else [])) def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided): device = "cuda" + fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn m, n, k, n_groups = 16, 32, 64, 4 s_int = int(strided) - a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), 
device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] - b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k] + a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(fp8_dtype)[::(1 + s_int), :, :k] + b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(fp8_dtype)[:, :k] self.assertTrue(a.is_contiguous() is not strided) self.assertTrue(b.is_contiguous() is not strided) scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m) diff --git a/third_party/fbgemm b/third_party/fbgemm index 157e88b750c4..0adf628317e0 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 157e88b750c452bef2ab4653fe9d1eeb151ce4c3 +Subproject commit 0adf628317e0cea414f66dcca901e0b85280fdb1 diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 07f22eab3f01..a61fc8559357 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -7313,13 +7313,18 @@ def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype): out_dtype = out_dtype or mat1.dtype - alignment = 16 // out_dtype.itemsize - size_padded = (out_size[-1] + alignment - 1) // alignment * alignment - if mat1_is_2d == mat2_is_2d: - out_stride = [out_size[1] * size_padded, size_padded, 1] + if torch.version.cuda: + alignment = 16 // out_dtype.itemsize + size_padded = (out_size[-1] + alignment - 1) // alignment * alignment + if mat1_is_2d == mat2_is_2d: + out_stride = [out_size[1] * size_padded, size_padded, 1] + else: + out_stride = [size_padded, 1] + out = torch.empty_strided( + out_size, out_stride, dtype=out_dtype, device=mat1.device + ) else: - out_stride = [size_padded, 1] - out = torch.empty_strided(out_size, out_stride, dtype=out_dtype, device=mat1.device) + out = torch.empty(out_size, dtype=out_dtype, device=mat1.device) return out @@ -7345,8 +7350,9 @@ def _meta_grouped_mm_common( # aten/src/ATen/native/cuda/Blas.cpp. if scaled: + fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn torch._check( - mat_a.dtype == torch.float8_e4m3fn and mat_b.dtype == torch.float8_e4m3fn, + mat_a.dtype == fp8_dtype and mat_b.dtype == fp8_dtype, lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", ) else: diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 9448d29baee2..0e95db1fdf37 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -107,8 +107,23 @@ def evaluate_platform_supports_fp8(): return SM90OrLater or torch.cuda.get_device_capability() == (8, 9) return False +def evaluate_platform_supports_fp8_grouped_gemm(): + if torch.cuda.is_available(): + if torch.version.hip: + if "USE_FBGEMM_GENAI" not in torch.__config__.show(): + return False + archs = ['gfx942'] + for arch in archs: + if arch in torch.cuda.get_device_properties(0).gcnArchName: + return True + else: + return SM90OrLater and not SM100OrLater + return False + PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8()) +PLATFORM_SUPPORTS_FP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_fp8_grouped_gemm()) + PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater) if TEST_NUMBA:
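
Illustrative usage (not part of the patch): a minimal sketch of how the new FP8 grouped GEMM path is exercised, mirroring test_scaled_grouped_gemm_2d_2d above. It assumes a build with USE_FBGEMM_GENAI=1 (the flag documented in setup.py) on a supported device, i.e. an SM90 CUDA GPU or a gfx942 ROCm GPU, since the ATen CMake logic above turns the option back off for other ROCm archs. Shapes, scale layouts and the dtype switch are taken from the tests; anything else is illustrative.

import torch

device = "cuda"
# ROCm uses the fnuz FP8 variant; CUDA uses e4m3fn (see the dtype switch in the tests).
fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn

m, n, k, n_groups = 16, 32, 64, 4
# Contiguous inputs only: the AMD path does not support non-contiguous inputs yet.
a = torch.randn(m, k * n_groups, device=device).to(fp8_dtype)
b = torch.randn(n, k * n_groups, device=device).to(fp8_dtype)
scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32)
# Group boundaries along K: [k, 2k, 3k, 4k].
offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)

# The op expects the second operand transposed (see the mat_b stride checks above).
out = torch._scaled_grouped_mm(a, b.t(), scale_a, scale_b, offs=offs,
                               out_dtype=torch.bfloat16,
                               use_fast_accum=False)  # backends may ignore fast accum
print(out.shape)  # one (m, n) bf16 block per group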