[Kernel] Marlin_24: Ensure the mma.sp instruction is using the ::ordered_metadata modifier (introduced with PTX 8.5) (#5136)

This commit is contained in:
Alexander Matveev
2024-05-30 22:02:11 -04:00
committed by GitHub
parent b35be5403f
commit 6d21fa1cad

View File

@ -32,7 +32,8 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
float* c = reinterpret_cast<float*>(&frag_c);
if (psel == 0) {
asm volatile(
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16."
"f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
@ -40,7 +41,8 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
"r"(e[0]));
asm volatile(
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16."
"f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n"
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
@ -49,7 +51,8 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
"r"(e[0]));
} else {
asm volatile(
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16."
"f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
@ -57,7 +60,8 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
"r"(e[0]));
asm volatile(
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16."
"f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n"
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])