bsr_dense_bmm(): enable more precise float32 support with float64 accumulators (#100882)

Float64 is available in Triton! This PR increases precision for float32 inputs by switching to a float64 accumulation dtype (and disabling TF32 for float32/float64 inputs).
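A minimal, self-contained sketch (mine, not part of the commit) of why the accumulator dtype matters: the same float32 inputs give measurably different results under float32 vs. float64 accumulation, and the gap grows with the reduction length `k` (an illustrative choice here):

    import torch

    torch.manual_seed(0)
    k = 1 << 16
    a = torch.randn(1, k, dtype=torch.float32)
    b = torch.randn(k, 1, dtype=torch.float32)

    # Same float32 values, two accumulation dtypes:
    acc32 = a @ b                    # accumulated (mostly) in float32
    acc64 = a.double() @ b.double()  # accumulated in float64; exact reference here

    # The gap is the rounding error introduced by float32 accumulation.
    print((acc32.double() - acc64).abs().item())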

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100882
Approved by: https://github.com/cpuhrsch
Author:    Nikita Vedeneev
Date:      2023-05-11 07:41:38 +00:00
Committer: PyTorch MergeBot
Parent:    979f55d3bc
Commit:    dd2c22f4bb

2 changed files with 14 additions and 4 deletions


@@ -3347,8 +3347,8 @@ class TestSparseCompressedTritonKernels(TestCase):
     @parametrize("index_dtype", [torch.int32, torch.int64])
     @onlyCUDA
     @skipIfRocm
-    @dtypes(torch.half, torch.bfloat16)
-    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [])
+    @dtypes(torch.half, torch.bfloat16, torch.float)
+    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
     def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
         from functools import partial


@@ -130,6 +130,8 @@ if _has_triton():
         output_row_block_stride,
         output_col_block_stride,
         # output epilogue
+        acc_dtype: tl.constexpr,
+        allow_tf32: tl.constexpr,
         GROUP_SIZE_ROW: tl.constexpr,
     ):
         batch_pid = tl.program_id(axis=2)
@@ -195,7 +197,7 @@ if _has_triton():
             + col_indices_stride * nnz_offset
         )
-        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32)
+        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), dtype=acc_dtype)

         for _ in range(row_nnz):
             values_block = tl.load(values_block_ptrs)
@@ -205,7 +207,7 @@ if _has_triton():
             dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)

             # do block mm
-            output_acc_block += tl.dot(values_block, dense_block)
+            output_acc_block += tl.dot(values_block, dense_block, allow_tf32=allow_tf32)

             # move val/col_index ptrs to the next block in the row
             values_block_ptrs += values_nnz_stride
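For context on the two flags (a sketch of mine, not code from this commit): both `acc_dtype` and `allow_tf32` are `tl.constexpr`, so each accumulator/TF32 combination compiles to a specialized kernel. A minimal standalone Triton kernel wired the same way, with illustrative names and shapes (newer Triton replaces the `allow_tf32` keyword with `input_precision`):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _tiny_mm_kernel(a_ptr, b_ptr, c_ptr,
                        BLOCK: tl.constexpr,
                        acc_dtype: tl.constexpr,
                        allow_tf32: tl.constexpr):
        # A single program computes one BLOCK x BLOCK tile: C = A @ B.
        offs = tl.arange(0, BLOCK)
        a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
        b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
        acc = tl.zeros((BLOCK, BLOCK), dtype=acc_dtype)
        acc += tl.dot(a, b, allow_tf32=allow_tf32)
        # The store casts the (possibly wider) accumulator back to C's dtype.
        tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], acc)

    a = torch.randn(16, 16, device="cuda", dtype=torch.float32)
    b = torch.randn(16, 16, device="cuda", dtype=torch.float32)
    c = torch.empty_like(a)
    # Mirrors the float32 branch of the dispatch below: float64 accumulator, TF32 off.
    _tiny_mm_kernel[(1,)](a, b, c, BLOCK=16, acc_dtype=tl.float64, allow_tf32=False)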
@@ -234,11 +236,19 @@ if _has_triton():
             dense: (0, -3, None),
             output: (0, -3, -4)
         }
+        if values.dtype in (torch.half, torch.bfloat16):
+            acc_dtype = tl.float32
+            allow_tf32 = True
+        else:
+            acc_dtype = tl.float64
+            allow_tf32 = False

         def kernel(grid, *sliced_tensors):
             _bsr_strided_dense_rowspace_kernel[grid](
                 *blocksize,
                 *ptr_stride_extractor(*sliced_tensors),
+                acc_dtype=acc_dtype,
+                allow_tf32=allow_tf32,
                 GROUP_SIZE_ROW=4,
                 num_stages=1,
                 num_warps=4
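End to end, the new float32 path can be exercised roughly like this (my usage sketch, assuming the `bsr_dense_mm` wrapper in `torch/sparse/_triton_ops.py` as the entry point and a CUDA device with Triton available):

    import torch
    from torch.sparse._triton_ops import bsr_dense_mm  # assumed entry point

    a = torch.randn(64, 64, device="cuda", dtype=torch.float32)
    dense = torch.randn(64, 64, device="cuda", dtype=torch.float32)
    bsr = a.to_sparse_bsr(blocksize=(16, 16))

    out = bsr_dense_mm(bsr, dense)     # float32 inputs now take the float64-accumulator branch
    ref = a.double() @ dense.double()  # exact reference for these float32 values
    # With float64 accumulation, the remaining error is dominated by the final float32 cast.
    print((out.double() - ref).abs().max().item())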