Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Build RowwiseScaledMM.cu for SM89 (#145676)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145676 Approved by: https://github.com/drisspg, https://github.com/malfet, https://github.com/eqy
Commit 2b00d211f0 (parent f40e013787), committed by PyTorch MergeBot
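The kernel this change builds for SM89 backs row-wise scaled FP8 matmul through torch._scaled_mm, which the tests touched below exercise. A minimal sketch of that call path, assuming illustrative shapes, unit scales, and an e4m3 input dtype (not code from this PR):

import torch

def rowwise_scaled_mm_demo(M=1024, K=512, N=2048):
    # Guard mirrors the "H100+, SM 8.9 and MI300+" requirement quoted in the test skips below.
    if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
        raise RuntimeError("FP8 row-wise scaling needs a CUDA device with SM 8.9 or newer")
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
    x_fp8 = x.to(torch.float8_e4m3fn)
    w_fp8 = w.to(torch.float8_e4m3fn).t()       # second operand must be column-major
    scale_x = torch.ones(M, 1, device="cuda")   # one fp32 scale per row of x
    scale_w = torch.ones(1, N, device="cuda")   # one fp32 scale per column of w.t()
    return torch._scaled_mm(x_fp8, w_fp8, scale_x, scale_w, out_dtype=torch.bfloat16)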
@@ -315,9 +315,6 @@ void f8f8bf16_rowwise_impl_sm89(
   using LayoutInputB = cutlass::layout::ColumnMajor;
   constexpr int AlignmentInputB = 16 / sizeof(DtypeB);

-  constexpr int AlignmentScale = 16 / sizeof(DtypeScale);
-  constexpr int AlignmentBias = 16 / sizeof(DtypeBias);
-
   using LayoutOutput = cutlass::layout::RowMajor;
   constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);

@@ -330,8 +327,8 @@ void f8f8bf16_rowwise_impl_sm89(

   // TODO: instead of fixing these values, implement logic alike to
   // what is used for SM90+.
-  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
-  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>;
+  using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
   using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
   constexpr auto NumStages = 4;

@@ -341,20 +338,6 @@ void f8f8bf16_rowwise_impl_sm89(
       cutlass::arch::OpMultiplyAdd>;
   constexpr auto NumEVTEpilogueStages = 1;

-  using ScaleTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          ThreadblockShape,
-          WarpShape,
-          DtypeScale,
-          AlignmentScale,
-          NumEVTEpilogueStages>;
-  using BiasTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          ThreadblockShape,
-          WarpShape,
-          DtypeBias,
-          AlignmentBias,
-          NumEVTEpilogueStages>;
   using OutputTileThreadMap =
       cutlass::epilogue::threadblock::OutputTileThreadLayout<
           ThreadblockShape,
@@ -365,25 +348,19 @@ void f8f8bf16_rowwise_impl_sm89(

   using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

-  using XScale =
-      cutlass::epilogue::threadblock::VisitorColBroadcast<
-          ScaleTileThreadMap,
-          DtypeScale,
-          cute::Stride<cute::_1, cute::_0, int64_t>>;
+  using XScale = cutlass::epilogue::threadblock::VisitorColBroadcast<
+      OutputTileThreadMap, DtypeScale,
+      cute::Stride<cute::_1, cute::_0, int64_t>>;
   using XScaleArguments = typename XScale::Arguments;

-  using WScale =
-      cutlass::epilogue::threadblock::VisitorRowBroadcast<
-          ScaleTileThreadMap,
-          DtypeScale,
-          cute::Stride<cute::_0, cute::_1, int64_t>>;
+  using WScale = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, DtypeScale,
+      cute::Stride<cute::_0, cute::_1, int64_t>>;
   using WScaleArguments = typename WScale::Arguments;

-  using Bias =
-      cutlass::epilogue::threadblock::VisitorRowBroadcast<
-          BiasTileThreadMap,
-          DtypeBias,
-          cute::Stride<cute::_0, cute::_1, int32_t>>;
+  using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, DtypeBias,
+      cute::Stride<cute::_0, cute::_1, int64_t>>;
   using BiasArguments = typename Bias::Arguments;

   using ApplyXScale = cutlass::epilogue::threadblock::VisitorCompute<
@@ -423,8 +400,7 @@ void f8f8bf16_rowwise_impl_sm89(
       Output,
       EVTApplyBias>;

-  using EVTKernel =
-      typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
+  using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
       DtypeA, LayoutInputA, cutlass::ComplexTransform::kNone, AlignmentInputA,
       DtypeB, LayoutInputB, cutlass::ComplexTransform::kNone, AlignmentInputB,
       DtypeOutput, LayoutOutput, AlignmentOutput,
@@ -442,7 +418,7 @@ void f8f8bf16_rowwise_impl_sm89(
       NumEVTEpilogueStages
   >::GemmKernel;

-  using Gemm = cutlass::gemm::device::GemmUniversalBase<EVTKernel>;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;

   cutlass::gemm::GemmCoord problem_size(M, N, K);
   constexpr auto SplitKFactor = 1;
@@ -475,14 +451,13 @@ void f8f8bf16_rowwise_impl_sm89(
           {} // ApplyXScale
         }, // EVTApplyXScale
         w_scale_arguments, // WScale
-        {}, // ApplyWScale
+        {} // ApplyWScale
       }, // EVTApplyWScale
       bias_arguments, // Bias
       {} // ApplyBias
     }, // EVTApplyBias
     output_arguments // Output
   }; // EVTOutput
-  constexpr auto AvailSms = -1;

   typename Gemm::Arguments arguments(
       cutlass::gemm::GemmUniversalMode::kGemm,
@@ -500,8 +475,7 @@ void f8f8bf16_rowwise_impl_sm89(
       problem_size.k(), // stride A
       problem_size.k(), // stride B
       0, // stride C (unused)
-      0, // stride D (unused)
-      AvailSms);
+      0); // stride D (unused)

   Gemm gemm;

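As a reading aid for the epilogue-visitor (EVT) changes above: XScale is a column broadcast of a per-row scale, WScale and Bias are row broadcasts, and Output casts to the output dtype. A plain-PyTorch sketch of that math (a reference for the reader, not the CUTLASS epilogue itself):

import torch

def rowwise_epilogue_reference(acc, x_scale, w_scale, bias=None, out_dtype=torch.bfloat16):
    # acc:     (M, N) float32 accumulator from the FP8 GEMM mainloop
    # x_scale: (M,)   per-row scale, broadcast down each column (VisitorColBroadcast)
    # w_scale: (N,)   per-column scale, broadcast across each row (VisitorRowBroadcast)
    # bias:    (N,)   optional row-broadcast bias
    out = acc * x_scale[:, None]   # ApplyXScale
    out = out * w_scale[None, :]   # ApplyWScale
    if bias is not None:
        out = out + bias[None, :]  # ApplyBias
    return out.to(out_dtype)       # Output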
@@ -76,7 +76,7 @@ if(INTERN_BUILD_ATEN_OPS)

  file(GLOB_RECURSE all_python "${CMAKE_CURRENT_LIST_DIR}/../torchgen/*.py")

-  # RowwiseScaled.cu requires sm90a flags
+  # RowwiseScaled.cu requires sm89/sm90a flags
  if(USE_CUDA)
    set(ROWWISE_SCALED_MM_FILE "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu")

@@ -84,11 +84,17 @@ if(INTERN_BUILD_ATEN_OPS)
    torch_cuda_get_nvcc_gencode_flag(EXISTING_ARCH_FLAGS)

    # Check NVCC version and existing arch flags
-    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND
-        EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
-      set_source_files_properties(${ROWWISE_SCALED_MM_FILE}
-        PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a")
+    set(ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "")
+    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)
+      if(EXISTING_ARCH_FLAGS MATCHES ".*compute_86.*")
+        list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_89,code=sm_89")
+      endif()
+      if(EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
+        list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
+      endif()
+    endif()
+    list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
+    set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")
  endif()

  set(GEN_ROCM_FLAG)
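The branch above keys off the gencode flags the rest of the build already uses, so RowwiseScaledMM.cu only gains sm_89 code objects when the build targets compute 8.6. From Python, the build's arch targets and the local device capability can be inspected as a quick sanity check (illustrative, not part of the PR):

import torch

print(torch.cuda.get_arch_list())          # architectures this PyTorch binary was compiled for, e.g. ['sm_80', 'sm_86', 'sm_90']
print(torch.cuda.get_device_capability())  # (8, 9) on SM89 (Ada-class) GPUs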
@@ -136,7 +136,10 @@ class DistMatrixOpsTest(DTensorTestBase):

     @with_comms
     @skip_unless_torch_gpu
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "torch._scaled_mm requires H100+")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
+    )
     def test_scaled_mm(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
         shrd0 = Shard(0)
@@ -708,8 +708,8 @@ class AOTInductorTestsTemplate:
         self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)

     @unittest.skipIf(
-        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
-        "FP8 is only supported on H100+ and sm_89 and MI300+ devices",
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     @skipIfRocm  # _scaled_mm_out_cuda is not compiled for ROCm platform
     @skipIfXpu
@@ -756,8 +756,8 @@ class AOTInductorTestsTemplate:
         )

     @unittest.skipIf(
-        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
-        "FP8 is only supported on H100+",
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     @skipIfRocm  # _scaled_mm_out_cuda is not compiled for ROCm platform
     @skipIfXpu
@@ -3324,7 +3324,7 @@ class AOTInductorTestsTemplate:
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FP8,
-        "FP8 is only supported on H100+ and sm_89 and MI300+ devices",
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     def test_runtime_checks_fp8(self):
         # cuda only
@@ -21,7 +21,7 @@ from torch.utils._triton import has_triton_tma_device
 torch.set_float32_matmul_precision("high")


-f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
+f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"

 # define the e4m3/e5m2 constants
 E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
@@ -1259,7 +1259,7 @@ class TestPrologueFusion(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FP8,
-        "FP8 is only supported on H100+ and sm_89 and MI300+ devices",
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     def test_low_precision(self):
         M = K = N = 128
@@ -839,7 +839,7 @@ class TestFlopCounter(TestCase):
     @unittest.skipIf(not HAS_CUDA, "CUDA not available")
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FP8,
-        "Does not support fp8 (pre-SM90 hardware on CUDA)",
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     def test_scaled_mm(self):
         dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
@@ -17,6 +17,7 @@ from torch.quantization._quantized_conversions import (
 from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import (
     SM53OrLater,
+    SM89OrLater,
     _get_torch_cuda_version,
     PLATFORM_SUPPORTS_FP8
 )
@@ -42,10 +43,8 @@ from torch.testing._internal.common_utils import (
 )

 _IS_SM8X = False
-_IS_SM9X = False
 if TEST_CUDA:
     _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
-    _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9

 # Protects against includes accidentally setting the default dtype
 assert torch.get_default_dtype() is torch.float32
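For context on the gate change: _IS_SM9X above was an exact major-version check, while the SM89OrLater helper the tests switch to admits compute capability 8.9 and newer. A rough module-level stand-in (an illustration, not the helper's exact internal definition):

import torch

# Approximation of an "SM 8.9 or later" gate via tuple comparison of the device capability.
SM89_OR_LATER = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)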
@@ -213,7 +212,7 @@ class TestMatmulCuda(TestCase):
         self.assertEqual(out1_gpu, out2_gpu[0])


-f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
+f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"

 if torch.version.hip:
     e4m3_type = torch.float8_e4m3fnuz
@@ -538,8 +537,7 @@ class TestFP8MatmulCuda(TestCase):
             lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
         )

-    @unittest.skipIf(PLATFORM_SUPPORTS_FP8,
-                     "This test is only for devices with compute capability < 8.9")
+    @unittest.skipIf(PLATFORM_SUPPORTS_FP8, f8_msg)
     def test_error_message_fp8_pre_sm89(self, device) -> None:
         (k, l, m) = (16, 48, 32)
         x = torch.rand((k, l), device=device).to(e4m3_type)
@@ -567,7 +565,7 @@
         self.assertEqual(out_fp8, out_fp8_s)

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
+    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
     @parametrize("use_fast_accum", [True, False])
     def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
         M, K, N = (1024, 512, 2048)
@@ -673,7 +671,7 @@
         )

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
+    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
     @parametrize("base_dtype", [torch.bfloat16])
     def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
         torch.manual_seed(42)
@@ -1047,7 +1047,10 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
             self.skipTest('cuSPARSELt not enabled')

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
+    )
     @xfailIfSM89
     @parametrize("dense_input_shape", [(256, 128)])
     def test_sparse_fp8fp8_mm(self, dense_input_shape, device):
@@ -1067,7 +1070,10 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         ):
             dense_result = torch.mm(A_fp8_sparse, B_fp8)

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
+    )
     @xfailIfSM89
     def test_sparse_semi_structured_scaled_mm_fp8(self, device) -> None:
         (k, l, m) = (32, 64, 32)
@@ -1084,7 +1090,10 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         out_fp32_sparse = out_fp8_sparse.to(torch.float32)
         torch.testing.assert_close(out_fp32, out_fp32_sparse, rtol=1e-1, atol=1e-1)

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
+    )
     @xfailIfSM89
     @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
     @parametrize("dense_input_shape", [(256, 128)])
@@ -31,7 +31,7 @@ from torch.testing._internal.common_device_type import \
     toleranceOverride, tol)
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
-    SM53OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
+    SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
     _get_torch_rocm_version,
 )
 from torch.testing._internal.common_utils import (
@@ -16211,7 +16211,7 @@ op_db: list[OpInfo] = [
         supports_out=True,
         supports_forward_ad=False,
         supports_autograd=False,
-        decorators=[skipCUDAIf(not SM90OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 9.0')],
+        decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
         skips=(
             # Sample inputs isn't really parametrized on dtype
             DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',