[ScaledMM] Fix NaNs in test for garbage input data (#144042)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144042
Approved by: https://github.com/janeyx99
This commit is contained in:
drisspg
2024-12-31 15:05:24 -08:00
committed by PyTorch MergeBot
parent b75f32b848
commit 9bf2a9a616

View File

@ -441,6 +441,9 @@ class TestFP8MatmulCuda(TestCase):
x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)
x.normal_()
y.normal_()
x_scale = tensor_to_scale(x, input_dtype).float()
y_scale = tensor_to_scale(y, input_dtype).float()