[ATen][Native][CUDA][SCALED_MM] limit f8f8bf16 rowwise scaled matmul to sm_90 (#145728)

The CUTLASS-based kernel for f8f8bf16 rowwise scaled matmul is specific to Hopper (sm_90) devices and cannot be reused on newer architectures without modifications. This PR adds a guard so the matmul is dispatched to this kernel only on sm_90 devices. Once a kernel for newer architectures is available, the guard can be removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145728
Approved by: https://github.com/Skylion007, https://github.com/eqy
Aidyn-A
2025-01-30 11:19:56 +00:00
committed by PyTorch MergeBot
parent 6bd19e65b1
commit ffa628169d


@@ -708,13 +708,13 @@ void dispatch_fp8_rowwise_kernel_on_sm(
     at::Tensor out) {
   cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
   const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9;
-  const bool sm90OrLater = properties != nullptr && properties->major >= 9;
-  if (!(sm89 || sm90OrLater)) {
+  const bool sm9x = properties != nullptr && properties->major == 9;
+  if (!(sm89 || sm9x)) {
     TORCH_CHECK(
         false, "Rowwise scaling is not currently supported on your device");
   }
-  if (sm90OrLater) {
+  if (sm9x) {
     dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<Types...>(XQ, WQ, x_scale, w_scale, bias, out);
   } else {
     f8f8bf16_rowwise_impl_sm89<Types...>(XQ, WQ, x_scale, w_scale, bias, out);
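
A minimal, self-contained sketch of the capability check the diff above applies. It reuses the same at::cuda::getCurrentDeviceProperties() query; the helper name supports_fp8_rowwise_scaling is hypothetical, used for illustration only, and is not part of the PR.

#include <ATen/cuda/CUDAContext.h>

// Returns true when the current device can take one of the rowwise-scaled
// FP8 paths: sm_89 (Ada) via the dedicated fallback kernel, or sm_9x
// (Hopper) via the CUTLASS kernel this PR restricts the dispatch to.
bool supports_fp8_rowwise_scaling() {
  cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
  if (properties == nullptr) {
    return false;
  }
  const bool sm89 = properties->major == 8 && properties->minor == 9;
  const bool sm9x = properties->major == 9;  // == 9, not >= 9: excludes newer architectures
  return sm89 || sm9x;
}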