# Owner(s): ["module: tests"]
"""
Fuzzer-discovered eager/compile divergence test cases.
All tests are marked as xfail since they represent known compilation bugs.
IF YOU ARE HERE YOU LIKELY DIDN'T DO ANYTHING WRONG. In fact, you probably did something right!
All of these tests are associated with bugs the fuzzer found. If one of these tests starts failing due to your PR,
it actually means your PR fixed the bug! Feel free to delete the test and close out the issue linked from the test.
"""
import pytest
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
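# To re-run a single repro locally (hypothetical invocation; adjust the path to your checkout):
#   pytest test/test_torchfuzz_repros.py -k test_fuzzer_issue_164484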
class TestFuzzerCompileIssues(TestCase):
"""Test cases for fuzzer-discovered eager/compile divergence issues."""
def setUp(self):
"""Configure common test settings."""
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._inductor.config.emulate_precision_casts = True
@pytest.mark.xfail(reason="Issue #164484")
def test_fuzzer_issue_164484(self):
torch.manual_seed(9157)
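# Repro: mixed float32/bfloat16 arithmetic chained through matmul, cat, sub/add, and layer_norm.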
def foo(arg0, arg1, arg2, arg3):
var_node_2 = torch.full((14, 16), 1.158473253250122, dtype=torch.float32)
var_node_1 = torch.nn.functional.relu(var_node_2)
var_node_6 = torch.full((14, 1), -0.94140625, dtype=torch.bfloat16)
var_node_7 = arg0 # size=(1, 16), stride=(16, 1), dtype=bfloat16
var_node_5 = torch.matmul(
var_node_6.to(torch.bfloat16), var_node_7.to(torch.bfloat16)
)
var_node_9 = torch.full((16,), 0.76953125, dtype=torch.bfloat16)
var_node_8 = torch.reshape(var_node_9, [16])
var_node_11 = torch.full((16,), 2.4375, dtype=torch.bfloat16)
var_node_10 = torch.reshape(var_node_11, [16])
var_node_4 = torch.cat([var_node_5, var_node_8, var_node_10], dim=1)
var_node_12 = arg1 # size=(14, 48), stride=(48, 1), dtype=bfloat16
var_node_3 = torch.sub(var_node_4, var_node_12)
var_node_0 = torch.add(var_node_1, var_node_3)
var_node_14 = torch.full((14, 48), 1.4375, dtype=torch.bfloat16)
var_node_13 = torch.nn.functional.layer_norm(var_node_14, [48])
result = torch.add(var_node_0, var_node_13)
output = result + arg2 + arg3
return output
arg0 = torch.rand(
[1, 16], dtype=torch.bfloat16, device="cuda", requires_grad=True
)
arg1 = torch.rand(
[14, 48], dtype=torch.bfloat16, device="cuda", requires_grad=True
)
arg2 = torch.tensor(
0.0, dtype=torch.bfloat16, device="cuda", requires_grad=True
)
arg3 = torch.tensor(
0.0, dtype=torch.bfloat16, device="cuda", requires_grad=True
)
out_eager = foo(arg0, arg1, arg2, arg3)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0, arg1, arg2, arg3)
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #164186")
def test_fuzzer_issue_164186(self):
torch.manual_seed(0)
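# Repro: in-place zero_() on clones around a contiguous().view() into a different 3-D shape.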
def foo(arg0):
t0 = arg0 # size=(714, 33), stride=(33, 1), dtype=float16, device=cuda
t1 = t0.clone()
t1.zero_()
t2 = t1.contiguous().view((34, 9, 77))
t3 = t2.clone()
t3.zero_()
output = t3
return output
arg0 = torch.rand(
[714, 33], dtype=torch.float16, device="cuda", requires_grad=True
)
out_eager = foo(arg0)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0)
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #164185")
def test_fuzzer_issue_164185(self):
torch.manual_seed(0)
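# Repro: bfloat16 mean, an embedding lookup via a clamped scalar index, nested pow, then a view.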
def foo(arg0, arg1, arg2):
t0 = arg0 # size=(349200, 5), stride=(5, 1), dtype=bfloat16, device=cuda
t1 = t0.mean(
dim=1
) # size=(349200,), stride=(1,), dtype=bfloat16, device=cuda
t2 = arg1 # size=(), stride=(), dtype=int64, device=cuda
t3 = arg2 # size=(50000, 349200), stride=(349200, 1), dtype=bfloat16, device=cuda
t4 = torch.nn.functional.embedding(
torch.clamp(t2, 0, t3.size(0) - 1).to(torch.long), t3
)
t5 = torch.pow(torch.pow(torch.pow(torch.pow(t1, t4), t4), t1), t1)
t6 = t5.contiguous().view((75, 97, 48))
output = t6
return output
arg0 = torch.rand(
[349200, 5], dtype=torch.bfloat16, device="cuda", requires_grad=True
)
arg1 = torch.randint(0, 50000, [], dtype=torch.int64, device="cuda")
arg2 = torch.rand(
[50000, 349200], dtype=torch.bfloat16, device="cuda", requires_grad=True
)
out_eager = foo(arg0, arg1, arg2)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0, arg1, arg2)
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #164157")
def test_fuzzer_issue_164157(self):
torch.manual_seed(0)
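# Repro: fill_() with a data-dependent scalar from .item(), float16 cat/std, then an embedding lookup.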
def foo(arg0, arg1, arg2, arg3, arg4, arg5):
t0 = arg0 # size=(47,), stride=(1,), dtype=int64, device=cuda
t1 = torch.tanh(t0) # size=(47,), stride=(1,), dtype=int64, device=cuda
t2 = arg1 # size=(), stride=(), dtype=int64, device=cuda
t3 = arg2 # size=(), stride=(), dtype=int64, device=cuda
t4 = t2 * t3 # size=(), stride=(), dtype=int64, device=cuda
t5 = t1.clone()
t5.fill_(t4.item())
t6 = (
arg3 # size=(256, 88, 1), stride=(88, 1, 1), dtype=float16, device=cuda
)
t7 = (
arg4 # size=(256, 88, 1), stride=(88, 1, 1), dtype=float16, device=cuda
)
t8 = (
arg5 # size=(256, 88, 1), stride=(88, 1, 1), dtype=float16, device=cuda
)
t9 = torch.cat([t6, t6, t7, t8], dim=2)
t10 = t9.std(dim=2)
t11 = torch.nn.functional.embedding(
torch.clamp(t5, 0, t10.size(0) - 1), t10
)
output = t11
return output
arg0 = torch.randint(0, 100, [47], dtype=torch.int64, device="cuda")
arg1 = torch.randint(0, 10, [], dtype=torch.int64, device="cuda")
arg2 = torch.randint(0, 10, [], dtype=torch.int64, device="cuda")
arg3 = torch.rand(
[256, 88, 1], dtype=torch.float16, device="cuda", requires_grad=True
)
arg4 = torch.rand(
[256, 88, 1], dtype=torch.float16, device="cuda", requires_grad=True
)
arg5 = torch.rand(
[256, 88, 1], dtype=torch.float16, device="cuda", requires_grad=True
)
out_eager = foo(arg0, arg1, arg2, arg3, arg4, arg5)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0, arg1, arg2, arg3, arg4, arg5)
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #164428")
def test_fuzzer_issue_164428_already_exists(self):
torch.manual_seed(6804)
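# Repro: float64 div and flatten, a 1-D x 2-D matmul, and chained subtraction.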
def foo(arg0, arg1, arg2):
var_node_4 = (
arg0 # size=(7, 1, 32), stride=(1, 1, 0), dtype=float64, device=cuda
)
var_node_5 = torch.full((7, 1, 32), -1.195053522845565, dtype=torch.float64)
var_node_3 = torch.div(var_node_4, var_node_5)
var_node_2 = torch.flatten(var_node_3)
var_node_8 = torch.full((2,), -0.8316502130341195, dtype=torch.float64)
var_node_9 = arg1 # size=(2, 224), stride=(224, 1), dtype=float64
var_node_7 = torch.matmul(
var_node_8.to(torch.float64), var_node_9.to(torch.float64)
)
var_node_10 = arg2 # size=(224,), stride=(1,), dtype=float64
var_node_6 = torch.sub(var_node_7, var_node_10)
var_node_1 = torch.sub(var_node_2, var_node_6)
output = var_node_1
return output
arg0 = torch.rand(
[7, 1, 32], dtype=torch.float64, device="cuda", requires_grad=True
)
arg1 = torch.rand(
[2, 224], dtype=torch.float64, device="cuda", requires_grad=True
)
arg2 = torch.rand([224], dtype=torch.float64, device="cuda", requires_grad=True)
out_eager = foo(arg0, arg1, arg2)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0, arg1, arg2)
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #164086")
def test_fuzzer_issue_164086(self):
torch.manual_seed(0)
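# Repro: float16 linear/max/cat/pow producing an embedding weight indexed with clamped int64 indices.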
def foo(arg0, arg1, arg2, arg3, arg4, arg5):
t0 = arg0 # size=(42, 56), stride=(56, 1), dtype=int64, device=cuda
t1 = torch.tanh(
t0
) # size=(42, 56), stride=(56, 1), dtype=int64, device=cuda
t2 = t1.clone()
t2.zero_() # size=(42, 56), stride=(56, 1), dtype=int64, device=cuda
t3 = (
arg1 # size=(50000, 128), stride=(128, 1), dtype=float16, device=cuda
)
t4 = arg2 # size=(46, 128), stride=(46, 1), dtype=float16, device=cuda
t5 = torch.nn.functional.linear(
t3, t4
) # size=(50000, 46), stride=(46, 1), dtype=float16, device=cuda
t6 = arg3 # size=(50000, 4, 46), stride=(184, 46, 1), dtype=float16, device=cuda
t7 = t6.max(
dim=1
).values # size=(50000, 46), stride=(46, 1), dtype=float16, device=cuda
t8 = arg4 # size=(25786, 46), stride=(46, 1), dtype=float16, device=cuda
t9 = arg5 # size=(24214, 46), stride=(46, 1), dtype=float16, device=cuda
t10 = torch.cat(
[t8, t9], dim=0
) # size=(50000, 46), stride=(46, 1), dtype=float16, device=cuda
t11 = torch.pow(
torch.pow(torch.pow(torch.pow(t5, t7), t10), t5), t7
) # size=(50000, 46), stride=(46, 1), dtype=float16, device=cuda
t12 = torch.nn.functional.embedding(
torch.clamp(t2, 0, t11.size(0) - 1).to(torch.long), t11
) # size=(42, 56, 46), stride=(2576, 46, 1), dtype=float16, device=cuda
output = t12
return output
arg0 = torch.randint(0, 1000, [42, 56], dtype=torch.int64, device="cuda")
arg1 = torch.rand(
[50000, 128], dtype=torch.float16, device="cuda", requires_grad=True
)
arg2 = torch.rand(
[46, 128], dtype=torch.float16, device="cuda", requires_grad=True
)
arg3 = torch.rand(
[50000, 4, 46], dtype=torch.float16, device="cuda", requires_grad=True
)
arg4 = torch.rand(
[25786, 46], dtype=torch.float16, device="cuda", requires_grad=True
)
arg5 = torch.rand(
[24214, 46], dtype=torch.float16, device="cuda", requires_grad=True
)
out_eager = foo(arg0, arg1, arg2, arg3, arg4, arg5)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0, arg1, arg2, arg3, arg4, arg5)
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #163877")
def test_fuzzer_issue_163877(self):
torch.manual_seed(0)
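# Repro: zero_() plus reshape, then fill_() with a data-dependent scalar obtained via .item().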
def foo(arg0, arg1):
t0 = arg0 # size=(401120, 3), stride=(3, 1), dtype=float32, device=cuda
t1 = t0.clone()
t1.zero_() # size=(401120, 3), stride=(3, 1), dtype=float32, device=cuda
t2 = t1.reshape(
(109, 115, 96)
) # size=(109, 115, 96), stride=(11040, 96, 1), dtype=float32, device=cuda
t3 = arg1 # size=(), stride=(), dtype=float32, device=cuda
t4 = t3.contiguous() # size=(), stride=(), dtype=float32, device=cuda
t5 = torch.nn.functional.relu(
t4
) # size=(), stride=(), dtype=float32, device=cuda
t6 = t2.clone()
t6.fill_(
t5.item()
) # size=(109, 115, 96), stride=(11040, 96, 1), dtype=float32, device=cuda
output = t6
return output
arg0 = torch.rand(
[401120, 3], dtype=torch.float32, device="cuda", requires_grad=True
)
arg1 = torch.rand([], dtype=torch.float32, device="cuda", requires_grad=True)
out_eager = foo(arg0, arg1)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0, arg1)
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #163971")
def test_fuzzer_issue_163971(self):
torch.manual_seed(0)
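# Repro: softmax -> gelu -> softmax chain on a 0-d bfloat16 tensor.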
def foo(arg0):
t0 = arg0 # size=(), stride=(), dtype=bfloat16, device=cuda
t1 = torch.softmax(
t0, dim=0
) # size=(), stride=(), dtype=bfloat16, device=cuda
t2 = torch.nn.functional.gelu(
t1
) # size=(), stride=(), dtype=bfloat16, device=cuda
t3 = torch.softmax(
t2, dim=0
) # size=(), stride=(), dtype=bfloat16, device=cuda
output = t3
return output
arg0 = torch.rand([], dtype=torch.bfloat16, device="cuda", requires_grad=True)
out_eager = foo(arg0)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0)
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #164059")
def test_fuzzer_issue_164059(self):
torch.manual_seed(0)
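# Repro: zero_() and contiguous().view(), unused integer scalar arithmetic, then layer_norm.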
def foo(arg0, arg1, arg2):
t0 = arg0 # size=(16, 38073, 1), stride=(38073, 1, 1), dtype=float32, device=cuda
t1 = t0.clone()
t1.zero_() # size=(16, 38073, 1), stride=(38073, 1, 1), dtype=float32, device=cuda
t2 = t1.contiguous().view(
(49, 112, 111)
) # size=(49, 112, 111), stride=(12432, 111, 1), dtype=float32, device=cuda
t3 = arg1 # size=(1,), stride=(1,), dtype=int64, device=cuda
t4 = arg2 # size=(1,), stride=(1,), dtype=int64, device=cuda
t5 = t3 + t3 + t4 # size=(1,), stride=(1,), dtype=int64, device=cuda
t6 = torch.exp(t5) # size=(1,), stride=(1,), dtype=int64, device=cuda
t7 = torch.nn.functional.layer_norm(
t2, (111,)
) # size=(49, 112, 111), stride=(12432, 111, 1), dtype=float32, device=cuda
output = t7
return output
arg0 = torch.rand(
[16, 38073, 1], dtype=torch.float32, device="cuda", requires_grad=True
)
arg1 = torch.randint(0, 1000, [1], dtype=torch.int64, device="cuda")
arg2 = torch.randint(0, 1000, [1], dtype=torch.int64, device="cuda")
out_eager = foo(arg0, arg1, arg2)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0, arg1, arg2)
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #164088")
def test_fuzzer_issue_164088(self):
torch.manual_seed(0)
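# Repro: fill_() with a data-dependent bfloat16 scalar built from min/silu/div reductions.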
def foo(arg0, arg1, arg2, arg3, arg4):
t0 = arg0 # size=(23, 4), stride=(4, 1), dtype=bfloat16, device=cuda
t1 = t0.clone()
t1.zero_() # size=(23, 4), stride=(4, 1), dtype=bfloat16, device=cuda
t2 = t1.contiguous().view(
(92,)
) # size=(92,), stride=(1,), dtype=bfloat16, device=cuda
t3 = arg1 # size=(5, 4, 5), stride=(20, 5, 1), dtype=bfloat16, device=cuda
t4 = t3.min() # size=(), stride=(), dtype=bfloat16, device=cuda
t5 = arg2 # size=(), stride=(), dtype=bfloat16, device=cuda
t6 = torch.nn.functional.silu(
t5
) # size=(), stride=(), dtype=bfloat16, device=cuda
t7 = arg3 # size=(3, 2, 3), stride=(6, 3, 1), dtype=bfloat16, device=cuda
t8 = t7.min() # size=(), stride=(), dtype=bfloat16, device=cuda
t9 = arg4 # size=(), stride=(), dtype=bfloat16, device=cuda
t10 = ((t8) / t9) / t9 # size=(), stride=(), dtype=bfloat16, device=cuda
t11 = (
t4 + t4 + t6 + t10 + t8
) # size=(), stride=(), dtype=bfloat16, device=cuda
t12 = t2.clone()
t12.fill_(
t11.item()
) # size=(92,), stride=(1,), dtype=bfloat16, device=cuda
output = t12
return output
arg0 = torch.rand(
[23, 4], dtype=torch.bfloat16, device="cuda", requires_grad=True
)
arg1 = torch.rand(
[5, 4, 5], dtype=torch.bfloat16, device="cuda", requires_grad=True
)
arg2 = torch.rand([], dtype=torch.bfloat16, device="cuda", requires_grad=True)
arg3 = torch.rand(
[3, 2, 3], dtype=torch.bfloat16, device="cuda", requires_grad=True
)
arg4 = torch.rand([], dtype=torch.bfloat16, device="cuda", requires_grad=True)
out_eager = foo(arg0, arg1, arg2, arg3, arg4)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0, arg1, arg2, arg3, arg4)
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #163894")
def test_fuzzer_issue_163894(self):
torch.manual_seed(9)
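# Repro: data-dependent nonzero() (empty result) with a numel() == 0 fallback and mixed int32/int64 arithmetic.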
def foo(arg0):
var_node_1 = arg0 # size=(1, 2), stride=(2, 1), dtype=int64, device=cuda
var_node_5 = torch.full(
(1, 2), -66, dtype=torch.int32
) # size=(1, 2), stride=(2, 1), dtype=int32, device=cuda
var_node_6 = torch.full(
(1, 2), 77, dtype=torch.int64
) # size=(1, 2), stride=(2, 1), dtype=int64, device=cuda
var_node_4 = torch.ops.aten.add(
var_node_5, var_node_6
) # size=(1, 2), stride=(2, 1), dtype=int32, device=cuda
var_node_7 = torch.full(
(1, 2), -64, dtype=torch.int32
) # size=(1, 2), stride=(2, 1), dtype=int32, device=cuda
var_node_3 = torch.ops.aten.mul(
var_node_4, var_node_7
) # size=(1, 2), stride=(2, 1), dtype=int32, device=cuda
var_node_9 = torch.full(
(3, 4), False, dtype=torch.bool
) # size=(3, 4), stride=(4, 1), dtype=bool, device=cuda
var_node_8 = torch.nonzero(
var_node_9
) # size=(0, 2), stride=(2, 1), dtype=int64, device=cuda
if var_node_8.numel() == 0:
var_node_8 = torch.zeros((1, 2), dtype=torch.int64, device="cuda")
var_node_2 = torch.ops.aten.add(var_node_3, var_node_8)
output = var_node_2.float()
return output
arg0 = torch.randint(0, 10, [1, 2], dtype=torch.int64, device="cuda")
out_eager = foo(arg0)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0)
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #164486")
def test_fuzzer_issue_164486(self):
torch.manual_seed(238)
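# Repro: int16 scalar add and squeeze feeding torch.div, then a cast to float.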
def foo(arg0):
var_node_2 = torch.full(
(), 1, dtype=torch.int16
) # size=(), stride=(), dtype=int16, device=cuda
var_node_3 = arg0 # size=(), stride=(), dtype=int16, device=cuda
var_node_1 = torch.add(
var_node_2, var_node_3
) # size=(), stride=(), dtype=int16, device=cuda
var_node_5 = torch.full(
(1,), 3, dtype=torch.int16
) # size=(1,), stride=(1,), dtype=int16, device=cuda
var_node_4 = torch.squeeze(
var_node_5
) # size=(), stride=(), dtype=int16, device=cuda
var_node_0 = torch.div(
var_node_1, var_node_4
) # size=(), stride=(), dtype=int16, device=cuda
result = var_node_0.float()
return result
arg0 = torch.randint(0, 10, [], dtype=torch.int16, device="cuda")
out_eager = foo(arg0)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0)
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #163674")
def test_fuzzer_issue_163674(self):
torch.manual_seed(0)
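# Repro: pow() between a zeroed float16 tensor and a float32 tensor filled via .item(), then a reshape.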
def foo(arg0, arg1, arg2):
t0 = arg0 # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float16, device=cuda
t1 = t0.clone()
t1.zero_() # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float16, device=cuda
t2 = arg1 # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float32, device=cuda
t3 = arg2 # size=(), stride=(), dtype=float32, device=cuda
t4 = t2.clone()
t4.fill_(
t3.item()
) # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float32, device=cuda
t5 = torch.pow(
t1, t4
) # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float32, device=cuda
t6 = t5.reshape(
(96, 69, 36)
) # size=(96, 69, 36), stride=(2484, 36, 1), dtype=float32, device=cuda
output = t6
return output
arg0 = torch.rand(
[79488, 1, 3, 1], dtype=torch.float16, device="cuda", requires_grad=True
)
arg1 = torch.rand(
[79488, 1, 3, 1], dtype=torch.float32, device="cuda", requires_grad=True
)
arg2 = torch.rand([], dtype=torch.float32, device="cuda", requires_grad=True)
out_eager = foo(arg0, arg1, arg2)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0, arg1, arg2)
out_compiled.sum().backward()
print("Compile Success! ✅")
if __name__ == "__main__":
run_tests()