mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Cutlass] Include fp8 headers in aoti cpp wrapper (#155173)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155173 Approved by: https://github.com/desertfire ghstack dependencies: #154829, #154835, #155195
This commit is contained in:
committed by
PyTorch MergeBot
parent
1ed243f01c
commit
3040ca6d0f
@ -495,7 +495,7 @@ class TestCutlassBackend(TestCase):
|
||||
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
|
||||
@parametrize("dynamic", (False, True))
|
||||
@parametrize("use_aoti", (False,))
|
||||
@parametrize("use_aoti", (False, True))
|
||||
@parametrize("dtype", (torch.float8_e4m3fn,))
|
||||
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
|
||||
def test_max_autotune_cutlass_backend_fp8_scaled_mm(
|
||||
|
@ -314,7 +314,7 @@ DTYPE_TO_CUTLASS_TYPE = {
|
||||
**DTYPE_TO_CPP,
|
||||
torch.float16: "__half",
|
||||
torch.bfloat16: "__nv_bfloat16",
|
||||
torch.float8_e4m3fn: "cutlass::float_e4m3_t",
|
||||
torch.float8_e4m3fn: "__nv_fp8_e4m3",
|
||||
}
|
||||
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#endif
|
||||
|
||||
namespace torch::aot_inductor {
|
||||
|
Reference in New Issue
Block a user