Summary:
In this PR we integrate the [FBGEMM AMD FP8 rowwise scaling grouped GEMM kernel](https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped) to add support for the `_scaled_grouped_mm` API on AMD. `_scaled_grouped_mm` is currently supported on Nvidia (see `9faef3d17c/aten/src/ATen/native/cuda/Blas.cpp`, L1614); this PR aims to bring parity to AMD. Related: [[RFC]: PyTorch Low-Precision GEMMs Public API](https://github.com/pytorch/pytorch/issues/157950) #157950.
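For readers less familiar with rowwise scaling, the following is a minimal plain-Python sketch of what a rowwise-scaled grouped GEMM computes. It is not the FBGEMM/CK kernel and does not call `_scaled_grouped_mm`. The helper names are hypothetical; the sketch assumes one float scale per row of A and per row of B (with B stored as N x K), uses a placeholder `FP8_MAX = 240.0`, and skips the actual FP8 cast, so it only shows why both scales can be applied once, after accumulation.

```python
# Plain-Python reference for a rowwise-scaled grouped GEMM (illustrative only).
# Assumptions: FP8 rounding is not modelled, FP8_MAX is a placeholder for the
# format's largest finite value, and each group's B is stored as N x K rows.

FP8_MAX = 240.0  # assumption: max finite value of float8_e4m3fnuz


def rowwise_quantize(rows):
    """Scale each row into [-FP8_MAX, FP8_MAX]; return (scaled_rows, per_row_scales)."""
    scaled, scales = [], []
    for row in rows:
        amax = max(abs(x) for x in row) or 1.0  # guard against all-zero rows
        scale = amax / FP8_MAX
        scaled.append([x / scale for x in row])  # a real kernel would cast to FP8 here
        scales.append(scale)
    return scaled, scales


def scaled_grouped_mm_ref(groups):
    """groups: list of (A, B) pairs, A is M x K and B is N x K; returns one M x N output per group."""
    outputs = []
    for a, b in groups:
        qa, scale_a = rowwise_quantize(a)  # one scale per row of A
        qb, scale_b = rowwise_quantize(b)  # one scale per row of B (= output column)
        out = []
        for i, a_row in enumerate(qa):
            out_row = []
            for j, b_row in enumerate(qb):
                acc = sum(x * y for x, y in zip(a_row, b_row))  # accumulate scaled values over K
                out_row.append(acc * scale_a[i] * scale_b[j])  # apply both scales once, after accumulation
            out.append(out_row)
        outputs.append(out)
    return outputs


# Two groups with different M and N; each output should equal A @ B^T up to float rounding.
groups = [
    ([[1.0, 2.0], [3.0, -4.0]], [[0.5, 1.0], [2.0, 0.0], [1.0, 1.0]]),
    ([[10.0, 0.0]], [[1.0, -1.0]]),
]
out = scaled_grouped_mm_ref(groups)
```

Because each scale depends only on the output row `i` or column `j`, it factors out of the sum over K, which is what lets a fused kernel apply the scales in its epilogue rather than per element of the inputs.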
The kernel is built with the Composable Kernel (CK) framework. Only MI300X is currently supported; we plan to add MI350X support in the near future. The supported data type is FP8 e4m3.
Kernel support is gated behind the `USE_FBGEMM_GENAI` build flag; we hope to enable it by default for relevant AMD builds.
Note that we also update the `third_party/fbgemm` submodule to 0adf62831 to pick up the required FBGEMM changes.
Test Plan:
**Hipify & build**
```
python tools/amd_build/build_amd.py
USE_FBGEMM_GENAI=1 python setup.py develop
```
**Unit tests**
```
python test/test_matmul_cuda.py -- TestFP8MatmulCUDA
Ran 488 tests in 32.969s
OK (skipped=454)
```
**Performance Sample**
| G | M | N | K | Runtime (ms) | GB/s | TFLOPS |
| -- | -- | -- | -- | -- | -- | -- |
| 128 | 1 | 2048 | 5120 | 0.37| 3590 | 7.17 |
| 128 | 64 | 2048 | 5120 | 0.51| 2792 | 338.34 |
| 128 | 128 | 2048 | 5120 | 0.66| 2272 | 522.72 |
| 128 | 1 | 5120 | 1024 | 0.21| 3224 | 6.43 |
| 128 | 64 | 5120 | 1024 | 0.29| 2590 | 291.40 |
| 128 | 128 | 5120 | 1024 | 0.40| 2165 | 434.76 |
| 128 | 1 | 4096 | 4096 | 0.69| 3126 | 6.25 |
| 128 | 64 | 4096 | 4096 | 0.85| 2655 | 324.66 |
| 128 | 128 | 4096 | 4096 | 1.10| 2142 | 501.40 |
| 128 | 1 | 8192 | 8192 | 2.45| 3508 | 7.01 |
| 128 | 64 | 8192 | 8192 | 3.27| 2692 | 336.74 |
| 128 | 128 | 8192 | 8192 | 4.04| 2224 | 543.76 |
| 16 | 1 | 2048 | 5120 | 0.04| 3928 | 7.85 |
| 16 | 64 | 2048 | 5120 | 0.05| 3295 | 399.29 |
| 16 | 128 | 2048 | 5120 | 0.07| 2558 | 588.69 |
| 16 | 1 | 5120 | 1024 | 0.03| 3119 | 6.23 |
| 16 | 64 | 5120 | 1024 | 0.03| 2849 | 320.62 |
| 16 | 128 | 5120 | 1024 | 0.05| 2013 | 404.11 |
| 16 | 1 | 4096 | 4096 | 0.06| 4512 | 9.02 |
| 16 | 64 | 4096 | 4096 | 0.09| 3124 | 381.95 |
| 16 | 128 | 4096 | 4096 | 0.13| 2340 | 547.67 |
| 16 | 1 | 8192 | 8192 | 0.32| 3374 | 6.75 |
| 16 | 64 | 8192 | 8192 | 0.42| 2593 | 324.28 |
| 16 | 128 | 8192 | 8192 | 0.53| 2120 | 518.36 |
- G is the number of groups; M, N, K are the per-group GEMM dimensions
- Measured with ROCm 6.4.1
- Timings collected with `triton.testing.do_bench_cudagraph`
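The TFLOPS and GB/s columns can be reproduced approximately from G, M, N, K and the runtime. The sketch below assumes 1-byte FP8 operands, a 2-byte (bf16) output, fp32 rowwise scales and decimal units; the PR does not state how bytes were counted, so treat `grouped_gemm_metrics` (a hypothetical helper) as an approximation rather than the script used to build the table.

```python
# Approximate derivation of the TFLOPS and GB/s columns from G, M, N, K and runtime.
# Assumptions (not stated in the PR): 1-byte FP8 inputs, 2-byte (bf16) output,
# fp32 rowwise scales, and decimal units (1 GB = 1e9 bytes, 1 TFLOP = 1e12 FLOPs).


def grouped_gemm_metrics(g, m, n, k, runtime_ms):
    """Return (tflops, gb_per_s) for G independent M x K by K x N GEMMs."""
    seconds = runtime_ms * 1e-3
    flops = 2 * g * m * n * k  # one multiply and one add per (m, n, k) triple, per group
    bytes_moved = (
        g * (m * k + n * k) * 1  # FP8 A and B operands
        + g * m * n * 2  # bf16 output
        + g * (m + n) * 4  # fp32 rowwise scales for A and B
    )
    return flops / seconds / 1e12, bytes_moved / seconds / 1e9


tflops, gbps = grouped_gemm_metrics(128, 1, 2048, 5120, 0.37)
```

For the first table row this gives roughly 7.26 TFLOPS and 3630 GB/s, about 1% above the reported 7.17 and 3590; the gap is consistent with the runtime column being rounded to two decimals.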
**Binary size with gfx942 arch**
- Before: `116103856 Jul 23 14:12 build/lib/libtorch_hip.so`
- After: `118860960 Jul 23 14:29 build/lib/libtorch_hip.so`
The difference is 2757104 bytes (~2.6 MiB).
Reviewers: @drisspg @ngimel @jwfromm @jeffdaily
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159075
Approved by: https://github.com/drisspg