bsr_dense_bmm(): enable more precise float32 support with float64 accumulators (#100882)
Float64 is available in Triton! This PR increases precision for float32 inputs by using a float64 accumulation dtype.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100882
Approved by: https://github.com/cpuhrsch
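Why a wider accumulator helps: each tl.dot adds many partial products into the accumulator, and with a float32 accumulator those additions themselves round at float32 precision. Below is a minimal illustration (not part of the PR, plain PyTorch on CPU) comparing a float32 matmul, which accumulates in float32, against a float64 reference; the gap it prints is the kind of error that float64 accumulation removes for float32 inputs.

import torch

# Compare float32 accumulation against a float64 reference for the same inputs.
a = torch.randn(1024, 1024, dtype=torch.float32)
b = torch.randn(1024, 1024, dtype=torch.float32)

ref = a.double() @ b.double()          # accumulate in float64 (reference)
out = (a @ b).double()                 # accumulate in float32, then widen
print((out - ref).abs().max().item())  # visibly larger than float64 rounding alone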
Committed by: PyTorch MergeBot
Parent: 979f55d3bc
Commit: dd2c22f4bb
@@ -3347,8 +3347,8 @@ class TestSparseCompressedTritonKernels(TestCase):
     @parametrize("index_dtype", [torch.int32, torch.int64])
     @onlyCUDA
     @skipIfRocm
-    @dtypes(torch.half, torch.bfloat16)
-    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [])
+    @dtypes(torch.half, torch.bfloat16, torch.float)
+    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
     def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
         from functools import partial
@@ -130,6 +130,8 @@ if _has_triton():
         output_row_block_stride,
         output_col_block_stride,
         # output epilogue
+        acc_dtype: tl.constexpr,
+        allow_tf32: tl.constexpr,
         GROUP_SIZE_ROW: tl.constexpr,
     ):
         batch_pid = tl.program_id(axis=2)
@@ -195,7 +197,7 @@ if _has_triton():
             + col_indices_stride * nnz_offset
         )
 
-        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32)
+        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), dtype=acc_dtype)
         for _ in range(row_nnz):
             values_block = tl.load(values_block_ptrs)
 
@@ -205,7 +207,7 @@ if _has_triton():
             dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)
 
             # do block mm
-            output_acc_block += tl.dot(values_block, dense_block)
+            output_acc_block += tl.dot(values_block, dense_block, allow_tf32=allow_tf32)
 
             # move val/col_index ptrs to the next block in the row
             values_block_ptrs += values_nnz_stride
@@ -234,11 +236,19 @@ if _has_triton():
             dense: (0, -3, None),
             output: (0, -3, -4)
         }
+        if values.dtype in (torch.half, torch.bfloat16):
+            acc_dtype = tl.float32
+            allow_tf32 = True
+        else:
+            acc_dtype = tl.float64
+            allow_tf32 = False
 
         def kernel(grid, *sliced_tensors):
             _bsr_strided_dense_rowspace_kernel[grid](
                 *blocksize,
                 *ptr_stride_extractor(*sliced_tensors),
+                acc_dtype=acc_dtype,
+                allow_tf32=allow_tf32,
                 GROUP_SIZE_ROW=4,
                 num_stages=1,
                 num_warps=4
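With the new dispatch above, half/bfloat16 inputs keep float32 accumulation (with TF32 enabled), while float32 inputs now accumulate in float64 with TF32 disabled. A hedged usage sketch of the effect, assuming a CUDA device with Triton and the private helper torch.sparse._triton_ops.bsr_dense_mm that the updated test exercises; the sizes and random inputs here are illustrative, not from the PR:

import torch
from torch.sparse._triton_ops import bsr_dense_mm

if torch.cuda.is_available():
    lhs = torch.randn(256, 256, dtype=torch.float32, device="cuda")
    rhs = torch.randn(256, 256, dtype=torch.float32, device="cuda")
    bsr = lhs.to_sparse_bsr((32, 32))   # BSR operand with 32x32 blocks

    out = bsr_dense_mm(bsr, rhs)        # float32 inputs, float64 accumulation in the kernel
    ref = lhs.double() @ rhs.double()   # float64 dense reference
    # Small difference, dominated by rounding the kernel output back to float32.
    print((out.double() - ref).abs().max().item())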