Add optimal triton kernel parameters to bsr_dense_mm and scatter_mm for bfloat16 and float32 dtypes (#113553)

As in the title.

This PR is a follow-up to PR https://github.com/pytorch/pytorch/pull/112737 to address bfloat16 and float32 dtype cases. The performance increase is as follows (`NVIDIA A100-SXM4-80GB`):

- bsr_scatter_mm and bfloat16
  - for blocksize 16x16, the average/maximum speed up is about 29/75 %.
  - for blocksize 32x32, the average/maximum speed up is about 23/58 %.
  - for blocksize 64x64, the average/maximum speed up is about 27/66 %.
  - for blocksize 128x128, the average/maximum speed up is about 33/72 %.
- bsr_dense_mm and bfloat16
  - for blocksize 16x16, the average/maximum speed up is about 47/61 %.
  - for blocksize 32x32, the average/maximum speed up is about 29/43 %.
  - for blocksize 64x64, the average/maximum speed up is about 21/41 %.
  - for blocksize 128x128, the average/maximum speed up is about 12/29 %.
- bsr_dense_mm and  float32
  - for blocksize 16x16, the average/maximum speed up is about 35/49 %.
  - for blocksize 32x32, the average/maximum speed up is about 2/5 %.
  - for blocksize 64x64, the average/maximum speed up is about 2/21 %.
  - for blocksize 128x128, the average/maximum speed up is about 79/84 %.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113553
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Pearu Peterson
2023-11-13 09:41:19 +00:00
committed by PyTorch MergeBot
parent ff82dcd8fa
commit e1c872e009
2 changed files with 1486 additions and 247 deletions

File diff suppressed because it is too large Load Diff