[ATen][Native][CUDA][SCALED_MM] limit f8f8bf16 rowwise scaled matmul to sm_90 (#145728)
The CUTLASS-based kernel for the f8f8bf16 rowwise scaled matmul is specific to Hopper (sm_90) devices and is not reusable on newer architectures without modifications. This PR adds a guard so this matmul dispatches only on sm_90; once a kernel for newer devices is available, the guard can be removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145728
Approved by: https://github.com/Skylion007, https://github.com/eqy
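For context, the gate boils down to a compute-capability check on the current device. Below is a minimal standalone sketch of the same predicate, assuming the plain CUDA runtime API rather than ATen's wrappers; the helper name is_sm90 is hypothetical and not part of the patch:

#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical helper mirroring the sm9x check in the diff below:
// true only when compute capability major == 9 (Hopper), which
// excludes both older (Ada, 8.9) and newer (10.x) architectures.
bool is_sm90() {
  int device = -1;
  if (cudaGetDevice(&device) != cudaSuccess) {
    return false;
  }
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
    return false;
  }
  return prop.major == 9;
}

int main() {
  std::printf("Hopper-only rowwise kernel eligible: %s\n",
              is_sm90() ? "yes" : "no");
  return 0;
}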
commit ffa628169d
parent 6bd19e65b1
committed by PyTorch MergeBot
@@ -708,13 +708,13 @@ void dispatch_fp8_rowwise_kernel_on_sm(
     at::Tensor out) {
   cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
   const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9;
-  const bool sm90OrLater = properties != nullptr && properties->major >= 9;
-  if (!(sm89 || sm90OrLater)) {
+  const bool sm9x = properties != nullptr && properties->major == 9;
+  if (!(sm89 || sm9x)) {
     TORCH_CHECK(
         false, "Rowwise scaling is not currently supported on your device");
   }
 
-  if (sm90OrLater) {
+  if (sm9x) {
     dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<Types...>(XQ, WQ, x_scale, w_scale, bias, out);
   } else {
     f8f8bf16_rowwise_impl_sm89<Types...>(XQ, WQ, x_scale, w_scale, bias, out);