From 9bf2a9a61679f0dfff90f3c52c70e9ed2961da09 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 31 Dec 2024 15:05:24 -0800 Subject: [PATCH] [ScaledMM] Fix NaNs in test for garbage input data (#144042) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144042 Approved by: https://github.com/janeyx99 --- test/test_matmul_cuda.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 9594208858f7..abd1b754dffa 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -441,6 +441,9 @@ class TestFP8MatmulCuda(TestCase): x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype) y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype) + x.normal_() + y.normal_() + x_scale = tensor_to_scale(x, input_dtype).float() y_scale = tensor_to_scale(y, input_dtype).float()