Build RowwiseScaledMM.cu for SM89 (#145676)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145676
Approved by: https://github.com/drisspg, https://github.com/malfet, https://github.com/eqy
This commit is contained in:
Aleksandar Samardžić
2025-01-31 23:48:47 +01:00
committed by PyTorch MergeBot
parent f40e013787
commit 2b00d211f0
10 changed files with 57 additions and 67 deletions

View File

@ -315,9 +315,6 @@ void f8f8bf16_rowwise_impl_sm89(
using LayoutInputB = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputB = 16 / sizeof(DtypeB);
constexpr int AlignmentScale = 16 / sizeof(DtypeScale);
constexpr int AlignmentBias = 16 / sizeof(DtypeBias);
using LayoutOutput = cutlass::layout::RowMajor;
constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
@ -330,8 +327,8 @@ void f8f8bf16_rowwise_impl_sm89(
// TODO: instead of fixing these values, implement logic alike to
// what is used for SM90+.
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
constexpr auto NumStages = 4;
@ -341,20 +338,6 @@ void f8f8bf16_rowwise_impl_sm89(
cutlass::arch::OpMultiplyAdd>;
constexpr auto NumEVTEpilogueStages = 1;
using ScaleTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape,
WarpShape,
DtypeScale,
AlignmentScale,
NumEVTEpilogueStages>;
using BiasTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape,
WarpShape,
DtypeBias,
AlignmentBias,
NumEVTEpilogueStages>;
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape,
@ -365,25 +348,19 @@ void f8f8bf16_rowwise_impl_sm89(
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using XScale =
cutlass::epilogue::threadblock::VisitorColBroadcast<
ScaleTileThreadMap,
DtypeScale,
cute::Stride<cute::_1, cute::_0, int64_t>>;
using XScale = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, DtypeScale,
cute::Stride<cute::_1, cute::_0, int64_t>>;
using XScaleArguments = typename XScale::Arguments;
using WScale =
cutlass::epilogue::threadblock::VisitorRowBroadcast<
ScaleTileThreadMap,
DtypeScale,
cute::Stride<cute::_0, cute::_1, int64_t>>;
using WScale = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, DtypeScale,
cute::Stride<cute::_0, cute::_1, int64_t>>;
using WScaleArguments = typename WScale::Arguments;
using Bias =
cutlass::epilogue::threadblock::VisitorRowBroadcast<
BiasTileThreadMap,
DtypeBias,
cute::Stride<cute::_0, cute::_1, int32_t>>;
using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, DtypeBias,
cute::Stride<cute::_0, cute::_1, int64_t>>;
using BiasArguments = typename Bias::Arguments;
using ApplyXScale = cutlass::epilogue::threadblock::VisitorCompute<
@ -423,8 +400,7 @@ void f8f8bf16_rowwise_impl_sm89(
Output,
EVTApplyBias>;
using EVTKernel =
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
DtypeA, LayoutInputA, cutlass::ComplexTransform::kNone, AlignmentInputA,
DtypeB, LayoutInputB, cutlass::ComplexTransform::kNone, AlignmentInputB,
DtypeOutput, LayoutOutput, AlignmentOutput,
@ -442,7 +418,7 @@ void f8f8bf16_rowwise_impl_sm89(
NumEVTEpilogueStages
>::GemmKernel;
using Gemm = cutlass::gemm::device::GemmUniversalBase<EVTKernel>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
cutlass::gemm::GemmCoord problem_size(M, N, K);
constexpr auto SplitKFactor = 1;
@ -475,14 +451,13 @@ void f8f8bf16_rowwise_impl_sm89(
{} // ApplyXScale
}, // EVTApplyXScale
w_scale_arguments, // WScale
{}, // ApplyWScale
{} // ApplyWScale
}, // EVTApplyWScale
bias_arguments, // Bias
{} // ApplyBias
}, // EVTApplyBias
output_arguments // Output
}; // EVTOutput
constexpr auto AvailSms = -1;
typename Gemm::Arguments arguments(
cutlass::gemm::GemmUniversalMode::kGemm,
@ -500,8 +475,7 @@ void f8f8bf16_rowwise_impl_sm89(
problem_size.k(), // stride A
problem_size.k(), // stride B
0, // stride C (unused)
0, // stride D (unused)
AvailSms);
0); // stride D (unused)
Gemm gemm;

View File

@ -76,7 +76,7 @@ if(INTERN_BUILD_ATEN_OPS)
file(GLOB_RECURSE all_python "${CMAKE_CURRENT_LIST_DIR}/../torchgen/*.py")
# RowwiseScaled.cu requires sm90a flags
# RowwiseScaled.cu requires sm89/sm90a flags
if(USE_CUDA)
set(ROWWISE_SCALED_MM_FILE "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu")
@ -84,11 +84,17 @@ if(INTERN_BUILD_ATEN_OPS)
torch_cuda_get_nvcc_gencode_flag(EXISTING_ARCH_FLAGS)
# Check NVCC version and existing arch flags
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND
EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
set_source_files_properties(${ROWWISE_SCALED_MM_FILE}
PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a")
set(ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "")
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)
if(EXISTING_ARCH_FLAGS MATCHES ".*compute_86.*")
list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_89,code=sm_89")
endif()
if(EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
endif()
endif()
list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")
endif()
set(GEN_ROCM_FLAG)

View File

@ -136,7 +136,10 @@ class DistMatrixOpsTest(DTensorTestBase):
@with_comms
@skip_unless_torch_gpu
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "torch._scaled_mm requires H100+")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
def test_scaled_mm(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
shrd0 = Shard(0)

View File

@ -708,8 +708,8 @@ class AOTInductorTestsTemplate:
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
@unittest.skipIf(
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
"FP8 is only supported on H100+ and sm_89 and MI300+ devices",
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
@skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform
@skipIfXpu
@ -756,8 +756,8 @@ class AOTInductorTestsTemplate:
)
@unittest.skipIf(
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
"FP8 is only supported on H100+",
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
@skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform
@skipIfXpu
@ -3324,7 +3324,7 @@ class AOTInductorTestsTemplate:
@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+ and sm_89 and MI300+ devices",
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
def test_runtime_checks_fp8(self):
# cuda only

View File

@ -21,7 +21,7 @@ from torch.utils._triton import has_triton_tma_device
torch.set_float32_matmul_precision("high")
f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
# define the e4m3/e5m2 constants
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max

View File

@ -1259,7 +1259,7 @@ class TestPrologueFusion(TestCase):
@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+ and sm_89 and MI300+ devices",
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
def test_low_precision(self):
M = K = N = 128

View File

@ -839,7 +839,7 @@ class TestFlopCounter(TestCase):
@unittest.skipIf(not HAS_CUDA, "CUDA not available")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"Does not support fp8 (pre-SM90 hardware on CUDA)",
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
def test_scaled_mm(self):
dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn

View File

@ -17,6 +17,7 @@ from torch.quantization._quantized_conversions import (
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
SM53OrLater,
SM89OrLater,
_get_torch_cuda_version,
PLATFORM_SUPPORTS_FP8
)
@ -42,10 +43,8 @@ from torch.testing._internal.common_utils import (
)
_IS_SM8X = False
_IS_SM9X = False
if TEST_CUDA:
_IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
_IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9
# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32
@ -213,7 +212,7 @@ class TestMatmulCuda(TestCase):
self.assertEqual(out1_gpu, out2_gpu[0])
f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
if torch.version.hip:
e4m3_type = torch.float8_e4m3fnuz
@ -538,8 +537,7 @@ class TestFP8MatmulCuda(TestCase):
lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
)
@unittest.skipIf(PLATFORM_SUPPORTS_FP8,
"This test is only for devices with compute capability < 8.9")
@unittest.skipIf(PLATFORM_SUPPORTS_FP8, f8_msg)
def test_error_message_fp8_pre_sm89(self, device) -> None:
(k, l, m) = (16, 48, 32)
x = torch.rand((k, l), device=device).to(e4m3_type)
@ -567,7 +565,7 @@ class TestFP8MatmulCuda(TestCase):
self.assertEqual(out_fp8, out_fp8_s)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
@parametrize("use_fast_accum", [True, False])
def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
M, K, N = (1024, 512, 2048)
@ -673,7 +671,7 @@ class TestFP8MatmulCuda(TestCase):
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
@parametrize("base_dtype", [torch.bfloat16])
def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
torch.manual_seed(42)

View File

@ -1047,7 +1047,10 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
self.skipTest('cuSPARSELt not enabled')
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
@xfailIfSM89
@parametrize("dense_input_shape", [(256, 128)])
def test_sparse_fp8fp8_mm(self, dense_input_shape, device):
@ -1067,7 +1070,10 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
):
dense_result = torch.mm(A_fp8_sparse, B_fp8)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
@xfailIfSM89
def test_sparse_semi_structured_scaled_mm_fp8(self, device) -> None:
(k, l, m) = (32, 64, 32)
@ -1084,7 +1090,10 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
out_fp32_sparse = out_fp8_sparse.to(torch.float32)
torch.testing.assert_close(out_fp32, out_fp32_sparse, rtol=1e-1, atol=1e-1)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
@xfailIfSM89
@parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
@parametrize("dense_input_shape", [(256, 128)])

View File

@ -31,7 +31,7 @@ from torch.testing._internal.common_device_type import \
toleranceOverride, tol)
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
SM53OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
_get_torch_rocm_version,
)
from torch.testing._internal.common_utils import (
@ -16211,7 +16211,7 @@ op_db: list[OpInfo] = [
supports_out=True,
supports_forward_ad=False,
supports_autograd=False,
decorators=[skipCUDAIf(not SM90OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 9.0')],
decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
skips=(
# Sample inputs isn't really parametrized on dtype
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',