mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
tc
This commit is contained in:
@ -798,6 +798,10 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_apply_dense(self, dtype) -> None:
|
||||
M, N = 256, 1024
|
||||
x = torch.randn([M, N], dtype=dtype, device="cuda")
|
||||
@ -836,6 +840,10 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_matmuls(self, dtype) -> None:
|
||||
M, N, K = 64, 256, 1024
|
||||
a = torch.randn([M, K], device="cuda", dtype=dtype)
|
||||
@ -871,6 +879,10 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
)
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_matmuls_mat_vec(self) -> None:
|
||||
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
|
||||
b = torch.randn([128], device="cuda", dtype=torch.float16)
|
||||
@ -881,6 +893,10 @@ class TestSparseSemiStructuredTraining(TestCase):
|
||||
torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
@unittest.skipIf(
|
||||
"RelWithAssert" in torch.__config__.show(),
|
||||
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
|
||||
)
|
||||
def test_sp24_matmuls_bmm(self) -> None:
|
||||
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
|
||||
b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
|
||||
|
Reference in New Issue
Block a user