Triton template IMA reads on B200 (#163460)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163460 Approved by: https://github.com/eqy, https://github.com/alexsamardzic
Committed by: PyTorch MergeBot
Parent: cf28ab2c88
Commit: 02da4753f5
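For context: "IMA" in the title refers to illegal memory access, and B200 is NVIDIA's Blackwell data-center GPU, which reports CUDA compute capability 10.x. The SM100OrLater flag used in the diff below keys on exactly that. A minimal sketch of the equivalent check (the helper name here is ours, not part of the change):

import torch

def _is_sm100_or_later() -> bool:
    # Blackwell (B200) devices report compute capability 10.x ("SM100").
    major, _minor = torch.cuda.get_device_capability(0)
    return major >= 10

if torch.cuda.is_available() and _is_sm100_or_later():
    print("SM100+ device: the grouped_mm Triton-template path is skipped by this change")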
@@ -24,7 +24,7 @@ from torch.testing._internal.common_cuda import (
     SM80OrLater,
     SM89OrLater,
     SM90OrLater,
-    xfailIfSM100OrLater,
+    SM100OrLater,
     xfailIfSM120OrLater,
     _get_torch_cuda_version,
     PLATFORM_SUPPORTS_FP8,
@@ -65,6 +65,8 @@ from torch.testing._internal.common_quantized import (
     generate_jagged_offs,
 )
 
+from torch._inductor.test_case import TestCase as InductorTestCase
+
 _IS_SM8X = False
 if TEST_CUDA:
     _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
@@ -82,7 +84,7 @@ def blas_library_context(backend):
     finally:
         torch.backends.cuda.preferred_blas_library(prev_backend)
 
-class TestMatmulCuda(TestCase):
+class TestMatmulCuda(InductorTestCase):
     def setUp(self):
         super().setUp()
         torch.backends.cuda.matmul.allow_tf32 = False
@@ -499,7 +501,6 @@ class TestMatmulCuda(TestCase):
         self.grouped_mm_helper(a, blist, gOlist, agradlist, bgradlist, outlist)
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
-    @xfailIfSM100OrLater
     # TODO(future PR): enable compile for torch._grouped_mm fallback path
     @unittest.skipIf(not SM90OrLater, "Grouped gemm with compile supported on SM90")
     @parametrize("op", ["2d/2d", "2d/3d", "3d/2d", "3d/3d"])
@@ -507,8 +508,8 @@ class TestMatmulCuda(TestCase):
     @parametrize("b_row_major", [False, True])
     @parametrize("max_autotune", [False, True])
     def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune):
-        torch._dynamo.reset()
-
+        if max_autotune and SM100OrLater:
+            self.skipTest("Triton templates not supported on SM100+ for grouped_mm")
         device = "cuda"
         dtype_AB = torch.bfloat16
         dtype_offset = torch.int32
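For reference, a rough sketch of the code path the new skip guards: compiling torch._grouped_mm with Inductor's max-autotune mode, which is where the Triton grouped-mm templates are considered. The shapes, group offsets, and 2d/3d layout below are illustrative only; torch._grouped_mm is a private API that requires an SM90+ GPU and bfloat16 inputs, and this snippet is not taken from the PR.

import torch

# Illustrative 2d/3d grouped GEMM: 4 groups of 16 rows each, one B matrix per group.
a = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)             # (total_M, K)
b = torch.randn(4, 128, 32, device="cuda", dtype=torch.bfloat16)          # (groups, K, N)
offs = torch.tensor([16, 32, 48, 64], device="cuda", dtype=torch.int32)   # cumulative row offsets

# max-autotune is the mode under which Inductor tries Triton templates;
# on SM100+ the updated test skips this configuration instead of hitting the IMA.
grouped_mm = torch.compile(torch._grouped_mm, mode="max-autotune")
out = grouped_mm(a, b, offs=offs)  # expected shape: (64, 32)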