[BE][cutlass backend] Fix subproc addmm tests (#160295)

Differential Revision: [D79977421](https://our.internmc.facebook.com/intern/diff/D79977421/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160295
Approved by: https://github.com/jingsh
This commit is contained in:
henrylhtsang
2025-08-10 20:37:44 -07:00
committed by PyTorch MergeBot
parent 0d40ff3b49
commit b90feeac86

View File

@ -294,20 +294,19 @@ class TestCutlassBackend(TestCase):
Y = torch.mm(a, b)
torch.testing.assert_close(Y_compiled, Y)
@unittest.skipIf(
True, "FIXME: Disabled temporarily since IMA or crashing in subprocess"
)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_cutlass_backend_subproc_addmm(self, shape_combo):
@parametrize("dtype", (torch.float16, torch.bfloat16))
def test_cutlass_backend_subproc_addmm(self, dtype):
"""
Test autotune_in_subproc works for addmm.
"""
M, N, K = 4096, 2048, 25728
dtype = torch.float16
a = torch.randn(M, K).cuda().half()
b = torch.randn(N, K).cuda().half().t()
a = torch.randn(M, K, dtype=dtype).cuda()
b = torch.randn(N, K, dtype=dtype).cuda().t()
x_shapes = [
(M, N),
@ -329,7 +328,10 @@ class TestCutlassBackend(TestCase):
}
):
for x_shape in x_shapes:
x = torch.randn(x_shape).cuda().half()
torch._dynamo.reset()
clear_caches()
x = torch.randn(x_shape).cuda().to(dtype)
Y_compiled = torch.compile(torch.addmm)(x, a, b, alpha=alpha, beta=beta)
Y = torch.addmm(x, a, b, alpha=alpha, beta=beta)
torch.testing.assert_close(Y_compiled, Y)