Add optimal triton kernel parameters to bsr_dense_mm and scatter_mm for bfloat16 and float32 dtypes (#113553)
As in the title. This PR is a follow-up to PR https://github.com/pytorch/pytorch/pull/112737 to address the bfloat16 and float32 dtype cases. The performance increase is as follows (`NVIDIA A100-SXM4-80GB`):

- bsr_scatter_mm and bfloat16:
  - for blocksize 16x16, the average/maximum speedup is about 29/75 %.
  - for blocksize 32x32, the average/maximum speedup is about 23/58 %.
  - for blocksize 64x64, the average/maximum speedup is about 27/66 %.
  - for blocksize 128x128, the average/maximum speedup is about 33/72 %.
- bsr_dense_mm and bfloat16:
  - for blocksize 16x16, the average/maximum speedup is about 47/61 %.
  - for blocksize 32x32, the average/maximum speedup is about 29/43 %.
  - for blocksize 64x64, the average/maximum speedup is about 21/41 %.
  - for blocksize 128x128, the average/maximum speedup is about 12/29 %.
- bsr_dense_mm and float32:
  - for blocksize 16x16, the average/maximum speedup is about 35/49 %.
  - for blocksize 32x32, the average/maximum speedup is about 2/5 %.
  - for blocksize 64x64, the average/maximum speedup is about 2/21 %.
  - for blocksize 128x128, the average/maximum speedup is about 79/84 %.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113553
Approved by: https://github.com/cpuhrsch
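For context, a minimal sketch (not from the PR) of how the tuned `bsr_dense_mm` kernel can be exercised with one of the dtype/blocksize combinations listed above. It assumes a CUDA device with Triton available; `bsr_dense_mm` lives in the private `torch.sparse._triton_ops` module, so its exact signature may differ between PyTorch versions, and the matrix sizes and 64x64 blocksize are illustrative values only.

```python
# Sketch: multiply a BSR (blocked sparse row) matrix by a dense matrix using
# the Triton-backed bsr_dense_mm. Assumes a CUDA device and a PyTorch build
# where torch.sparse._triton_ops.bsr_dense_mm is available.
import torch
from torch.sparse._triton_ops import bsr_dense_mm

device = "cuda"
dtype = torch.bfloat16  # the PR also tunes float32

# Illustrative shapes; the tuned kernel parameters are selected per
# (dtype, blocksize, shape) combination.
M, K, N = 1024, 1024, 1024
a = torch.randn(M, K, device=device, dtype=dtype)
b = torch.randn(K, N, device=device, dtype=dtype)

# Convert the left operand to BSR layout with a 64x64 blocksize.
a_bsr = a.to_sparse_bsr(blocksize=(64, 64))

# Runs the Triton kernel, picking up tuned launch parameters when a
# matching entry exists for this dtype/blocksize.
out = bsr_dense_mm(a_bsr, b)

# Rough correctness check against the dense reference result.
torch.testing.assert_close(out, a @ b, rtol=1e-2, atol=1e-2)
```

The same tensors can be fed through `bsr_scatter_mm` in that module to hit the other tuned code path, though its calling convention is more internal and not shown here.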
Committed by: PyTorch MergeBot
Parent: ff82dcd8fa
Commit: e1c872e009