mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE][cutlass backend] Fix subproc addmm tests (#160295)
Differential Revision: [D79977421](https://our.internmc.facebook.com/intern/diff/D79977421/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160295 Approved by: https://github.com/jingsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
0d40ff3b49
commit
b90feeac86
@@ -294,20 +294,19 @@ class TestCutlassBackend(TestCase):
|
||||
Y = torch.mm(a, b)
|
||||
torch.testing.assert_close(Y_compiled, Y)
|
||||
|
||||
@unittest.skipIf(
|
||||
True, "FIXME: Disabled temporarily since IMA or crashing in subprocess"
|
||||
)
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
|
||||
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
|
||||
def test_cutlass_backend_subproc_addmm(self, shape_combo):
|
||||
@parametrize("dtype", (torch.float16, torch.bfloat16))
|
||||
def test_cutlass_backend_subproc_addmm(self, dtype):
|
||||
"""
|
||||
Test autotune_in_subproc works for addmm.
|
||||
"""
|
||||
|
||||
M, N, K = 4096, 2048, 25728
|
||||
dtype = torch.float16
|
||||
|
||||
a = torch.randn(M, K).cuda().half()
|
||||
b = torch.randn(N, K).cuda().half().t()
|
||||
a = torch.randn(M, K, dtype=dtype).cuda()
|
||||
b = torch.randn(N, K, dtype=dtype).cuda().t()
|
||||
|
||||
x_shapes = [
|
||||
(M, N),
|
||||
@@ -329,7 +328,10 @@ class TestCutlassBackend(TestCase):
|
||||
}
|
||||
):
|
||||
for x_shape in x_shapes:
|
||||
x = torch.randn(x_shape).cuda().half()
|
||||
torch._dynamo.reset()
|
||||
clear_caches()
|
||||
|
||||
x = torch.randn(x_shape).cuda().to(dtype)
|
||||
Y_compiled = torch.compile(torch.addmm)(x, a, b, alpha=alpha, beta=beta)
|
||||
Y = torch.addmm(x, a, b, alpha=alpha, beta=beta)
|
||||
torch.testing.assert_close(Y_compiled, Y)
|
||||
|
Reference in New Issue
Block a user