Triton template IMA reads on B200 (#163460)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163460
Approved by: https://github.com/eqy, https://github.com/alexsamardzic
drisspg
2025-09-22 00:56:14 +00:00
committed by PyTorch MergeBot
parent cf28ab2c88
commit 02da4753f5
2 changed files with 7 additions and 6 deletions

@@ -24,7 +24,7 @@ from torch.testing._internal.common_cuda import (
     SM80OrLater,
     SM89OrLater,
     SM90OrLater,
-    xfailIfSM100OrLater,
+    SM100OrLater,
     xfailIfSM120OrLater,
     _get_torch_cuda_version,
     PLATFORM_SUPPORTS_FP8,
@@ -65,6 +65,8 @@ from torch.testing._internal.common_quantized import (
     generate_jagged_offs,
 )
+from torch._inductor.test_case import TestCase as InductorTestCase
+

 _IS_SM8X = False
 if TEST_CUDA:
     _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
@@ -82,7 +84,7 @@ def blas_library_context(backend):
     finally:
         torch.backends.cuda.preferred_blas_library(prev_backend)

-class TestMatmulCuda(TestCase):
+class TestMatmulCuda(InductorTestCase):
     def setUp(self):
         super().setUp()
         torch.backends.cuda.matmul.allow_tf32 = False
@@ -499,7 +501,6 @@ class TestMatmulCuda(TestCase):
         self.grouped_mm_helper(a, blist, gOlist, agradlist, bgradlist, outlist)

     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
-    @xfailIfSM100OrLater
     # TODO(future PR): enable compile for torch._grouped_mm fallback path
     @unittest.skipIf(not SM90OrLater, "Grouped gemm with compile supported on SM90")
     @parametrize("op", ["2d/2d", "2d/3d", "3d/2d", "3d/3d"])
@@ -507,8 +508,8 @@ class TestMatmulCuda(TestCase):
     @parametrize("b_row_major", [False, True])
     @parametrize("max_autotune", [False, True])
     def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune):
         torch._dynamo.reset()
+        if max_autotune and SM100OrLater:
+            self.skipTest("Triton templates not supported on SM100+ for grouped_mm")
         device = "cuda"
         dtype_AB = torch.bfloat16
         dtype_offset = torch.int32
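
For context, the new runtime gate mirrors the capability flags in torch.testing._internal.common_cuda: SM100OrLater is true on compute capability 10.0+ parts such as B200. Below is a minimal, self-contained sketch of the skip pattern this diff adds; the inline SM100OrLater definition and the trimmed test body are illustrative assumptions, not code from the commit.

# Sketch of the SM100+ runtime skip added in this commit.
# Assumption: SM100OrLater behaves like the common_cuda gate,
# i.e. true when the device compute capability is >= (10, 0).
import unittest

import torch

SM100OrLater = torch.cuda.is_available() and (
    torch.cuda.get_device_capability(0) >= (10, 0)
)


class TestMatmulCuda(unittest.TestCase):
    def test_grouped_gemm_compiled(self):
        max_autotune = True  # stand-in for the @parametrize'd flag
        # A runtime skip replaces the blanket @xfailIfSM100OrLater decorator:
        # only the max_autotune (Triton template) variants are skipped on
        # SM100+, so the non-autotune variants still run on B200.
        if max_autotune and SM100OrLater:
            self.skipTest("Triton templates not supported on SM100+ for grouped_mm")


if __name__ == "__main__":
    unittest.main()

Skipping at runtime rather than xfailing the whole test keeps the eager and non-max_autotune compile variants exercised on B200, which is presumably why the commit swaps the xfailIfSM100OrLater import for SM100OrLater. Rebasing TestMatmulCuda onto InductorTestCase likewise runs the compiled variants under the inductor test harness instead of the plain TestCase.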