mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ScaledMM] Fix NaNs in test for garbage input data (#144042)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144042 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
b75f32b848
commit
9bf2a9a616
@ -441,6 +441,9 @@ class TestFP8MatmulCuda(TestCase):
|
||||
x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
|
||||
y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)
|
||||
|
||||
x.normal_()
|
||||
y.normal_()
|
||||
|
||||
x_scale = tensor_to_scale(x, input_dtype).float()
|
||||
y_scale = tensor_to_scale(y, input_dtype).float()
|
||||
|
||||
|
Reference in New Issue
Block a user