[tests] Reduce sizes of unnecessarily large tensors to reduce OOM flakes (#158456)

Downsizes several test tensors that were far larger than needed to exercise the problem at hand, reducing OOM-related test flakes.
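For a rough sense of the savings, here is a minimal sketch (not part of the change) that estimates how much storage the old and new shapes need, assuming the test's rand_strided helper allocates just enough memory to cover the largest strided offset; approx_storage_bytes is an illustrative name introduced only for this example:

    import torch

    def approx_storage_bytes(shape, strides, dtype=torch.bfloat16):
        # A strided tensor needs storage covering its largest reachable offset,
        # so estimate bytes from the strides rather than the logical shape.
        last_offset = sum((dim - 1) * stride for dim, stride in zip(shape, strides))
        element_size = torch.tensor([], dtype=dtype).element_size()
        return (last_offset + 1) * element_size

    # Old x in these tests: shape (50257, 32768), strides (1, 50304) -> ~3.1 GiB
    # New x:                shape (50257, 2048),  strides (1, 50304) -> ~0.2 GiB
    print(approx_storage_bytes((50257, 32768), (1, 50304)) / 2**30)
    print(approx_storage_bytes((50257, 2048), (1, 50304)) / 2**30)

Under that assumption, shrinking the inner dimension from 32768 to 2048 cuts each of these buffers by roughly 16x.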

Fixes #126867

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158456
Approved by: https://github.com/desertfire
Benjamin Glass
2025-07-22 23:41:44 +00:00
committed by PyTorch MergeBot
parent 6100ed457c
commit cab96b5879


@@ -815,9 +815,9 @@ class TestMaxAutotune(TestCase):
Check https://github.com/pytorch/pytorch/issues/125437 for more details.
"""
x = rand_strided(
-(50257, 32768), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE
+(50257, 2048), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE
)
-y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE)
+y = rand_strided((2048, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE)
@torch.compile(mode="max-autotune")
def f(x, y):
@@ -830,9 +830,9 @@ class TestMaxAutotune(TestCase):
def test_non_contiguous_input_addmm(self):
b = torch.randn((768), dtype=torch.bfloat16, device=GPU_TYPE)
x = rand_strided(
-(50257, 32768), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE
+(50257, 2048), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE
)
-y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE)
+y = rand_strided((2048, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE)
@torch.compile(mode="max-autotune")
def f(x, y):
@@ -844,10 +844,10 @@ class TestMaxAutotune(TestCase):
def test_non_contiguous_input_bmm(self):
x = rand_strided(
-(1, 50257, 32768), (0, 1, 50304), dtype=torch.bfloat16, device=GPU_TYPE
+(1, 50257, 2048), (0, 1, 50304), dtype=torch.bfloat16, device=GPU_TYPE
)
y = rand_strided(
-(1, 32768, 768), (0, 768, 1), dtype=torch.bfloat16, device=GPU_TYPE
+(1, 2048, 768), (0, 768, 1), dtype=torch.bfloat16, device=GPU_TYPE
)
@torch.compile(mode="max-autotune")
@@ -861,16 +861,12 @@ class TestMaxAutotune(TestCase):
# TODO: fix accuracy failure of the triton template on XPU.
# and enable this test case.
@skipIfXpu
-@unittest.skipIf(
-os.getenv("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1",
-"OOM when running with TORCHINDUCTOR_CPP_WRAPPER https://github.com/pytorch/pytorch/issues/126867",
-)
def test_non_contiguous_input_mm_plus_mm(self):
-x1 = rand_strided((50257, 32768), (1, 50304), device=GPU_TYPE)
-y1 = rand_strided((32768, 768), (768, 1), device=GPU_TYPE)
+x1 = rand_strided((50257, 2048), (1, 50304), device=GPU_TYPE)
+y1 = rand_strided((2048, 768), (768, 1), device=GPU_TYPE)
-x2 = rand_strided((50257, 32768), (1, 50304), device=GPU_TYPE)
-y2 = rand_strided((32768, 768), (768, 1), device=GPU_TYPE)
+x2 = rand_strided((50257, 2048), (1, 50304), device=GPU_TYPE)
+y2 = rand_strided((2048, 768), (768, 1), device=GPU_TYPE)
@torch.compile(mode="max-autotune")
def f(x1, y1, x2, y2):