# Owner(s): ["module: tests"]
"""
Fuzzer-discovered eager/compile divergence test cases.

All tests are marked as xfail since they represent known compilation bugs.

IF YOU ARE HERE YOU LIKELY DIDN'T DO ANYTHING WRONG. In fact, you probably did something right!
All of these tests are associated with bugs the fuzzer found. If one of these tests starts failing due to your PR,
it actually means your PR fixed the bug! Feel free to delete the test and close out the issue linked from the test.
"""

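# To exercise one of these repros locally, run this file under pytest on a
# machine with CUDA (and Triton, for the inductor backend). The path below is
# illustrative; adjust it to wherever this file lives in your checkout:
#   pytest path/to/this_test_file.py -k 164484 --runxfail
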
import pytest

import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON


class TestFuzzerCompileIssues(TestCase):
    """Test cases for fuzzer-discovered eager/compile divergence issues."""
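
    # Shared pattern (a summary of what each test below does, not a new
    # mechanism): build the fuzzer-generated repro function, run it in eager
    # mode and backprop, then compile it with torch.compile(fullgraph=True,
    # dynamic=True) and repeat. The xfail marker records that the compiled
    # run is currently expected to diverge or error.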

    def setUp(self):
        """Configure common test settings."""
        # Let Dynamo capture .item()/scalar outputs instead of graph-breaking.
        torch._dynamo.config.capture_scalar_outputs = True
        # Let Dynamo capture ops with data-dependent output shapes (e.g. nonzero).
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        # Make Inductor round low-precision intermediates the way eager does,
        # so eager/compile comparisons are apples-to-apples.
        torch._inductor.config.emulate_precision_casts = True

    @pytest.mark.xfail(reason="Issue #164484")
    def test_fuzzer_issue_164484(self):
        torch.manual_seed(9157)

        def foo(arg0, arg1, arg2, arg3):
            var_node_2 = torch.full((14, 16), 1.158473253250122, dtype=torch.float32)
            var_node_1 = torch.nn.functional.relu(var_node_2)
            var_node_6 = torch.full((14, 1), -0.94140625, dtype=torch.bfloat16)
            var_node_7 = arg0  # size=(1, 16), stride=(16, 1), dtype=bfloat16
            var_node_5 = torch.matmul(
                var_node_6.to(torch.bfloat16), var_node_7.to(torch.bfloat16)
            )
            var_node_9 = torch.full((16,), 0.76953125, dtype=torch.bfloat16)
            var_node_8 = torch.reshape(var_node_9, [16])
            var_node_11 = torch.full((16,), 2.4375, dtype=torch.bfloat16)
            var_node_10 = torch.reshape(var_node_11, [16])
            var_node_4 = torch.cat([var_node_5, var_node_8, var_node_10], dim=1)
            var_node_12 = arg1  # size=(14, 48), stride=(48, 1), dtype=bfloat16
            var_node_3 = torch.sub(var_node_4, var_node_12)
            var_node_0 = torch.add(var_node_1, var_node_3)
            var_node_14 = torch.full((14, 48), 1.4375, dtype=torch.bfloat16)
            var_node_13 = torch.nn.functional.layer_norm(var_node_14, [48])
            result = torch.add(var_node_0, var_node_13)
            output = result + arg2 + arg3
            return output

        arg0 = torch.rand(
            [1, 16], dtype=torch.bfloat16, device="cuda", requires_grad=True
        )
        arg1 = torch.rand(
            [14, 48], dtype=torch.bfloat16, device="cuda", requires_grad=True
        )
        arg2 = torch.tensor(
            0.0, dtype=torch.bfloat16, device="cuda", requires_grad=True
        )
        arg3 = torch.tensor(
            0.0, dtype=torch.bfloat16, device="cuda", requires_grad=True
        )

        out_eager = foo(arg0, arg1, arg2, arg3)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0, arg1, arg2, arg3)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #164186")
    def test_fuzzer_issue_164186(self):
        torch.manual_seed(0)

        def foo(arg0):
            t0 = arg0  # size=(714, 33), stride=(33, 1), dtype=float16, device=cuda
            t1 = t0.clone()
            t1.zero_()
            t2 = t1.contiguous().view((34, 9, 77))
            t3 = t2.clone()
            t3.zero_()
            output = t3
            return output

        arg0 = torch.rand(
            [714, 33], dtype=torch.float16, device="cuda", requires_grad=True
        )

        out_eager = foo(arg0)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #164185")
    def test_fuzzer_issue_164185(self):
        torch.manual_seed(0)

        def foo(arg0, arg1, arg2):
            t0 = arg0  # size=(349200, 5), stride=(5, 1), dtype=bfloat16, device=cuda
            t1 = t0.mean(
                dim=1
            )  # size=(349200,), stride=(1,), dtype=bfloat16, device=cuda
            t2 = arg1  # size=(), stride=(), dtype=int64, device=cuda
            t3 = arg2  # size=(50000, 349200), stride=(50000, 1), dtype=bfloat16, device=cuda
            t4 = torch.nn.functional.embedding(
                torch.clamp(t2, 0, t3.size(0) - 1).to(torch.long), t3
            )
            t5 = torch.pow(torch.pow(torch.pow(torch.pow(t1, t4), t4), t1), t1)
            t6 = t5.contiguous().view((75, 97, 48))
            output = t6
            return output

        arg0 = torch.rand(
            [349200, 5], dtype=torch.bfloat16, device="cuda", requires_grad=True
        )
        arg1 = torch.randint(0, 50000, [], dtype=torch.int64, device="cuda")
        arg2 = torch.rand(
            [50000, 349200], dtype=torch.bfloat16, device="cuda", requires_grad=True
        )

        out_eager = foo(arg0, arg1, arg2)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0, arg1, arg2)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #164157")
    def test_fuzzer_issue_164157(self):
        torch.manual_seed(0)

        def foo(arg0, arg1, arg2, arg3, arg4, arg5):
            t0 = arg0  # size=(47,), stride=(1,), dtype=int64, device=cuda
            t1 = torch.tanh(t0)  # size=(47,), stride=(1,), dtype=int64, device=cuda
            t2 = arg1  # size=(), stride=(), dtype=int64, device=cuda
            t3 = arg2  # size=(), stride=(), dtype=int64, device=cuda
            t4 = t2 * t3  # size=(), stride=(), dtype=int64, device=cuda
            t5 = t1.clone()
            t5.fill_(t4.item())
            t6 = (
                arg3  # size=(256, 88, 1), stride=(88, 1, 1), dtype=float16, device=cuda
            )
            t7 = (
                arg4  # size=(256, 88, 1), stride=(88, 1, 1), dtype=float16, device=cuda
            )
            t8 = (
                arg5  # size=(256, 88, 1), stride=(88, 1, 1), dtype=float16, device=cuda
            )
            t9 = torch.cat([t6, t6, t7, t8], dim=2)
            t10 = t9.std(dim=2)
            t11 = torch.nn.functional.embedding(
                torch.clamp(t5, 0, t10.size(0) - 1), t10
            )
            output = t11
            return output

        arg0 = torch.randint(0, 100, [47], dtype=torch.int64, device="cuda")
        arg1 = torch.randint(0, 10, [], dtype=torch.int64, device="cuda")
        arg2 = torch.randint(0, 10, [], dtype=torch.int64, device="cuda")
        arg3 = torch.rand(
            [256, 88, 1], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg4 = torch.rand(
            [256, 88, 1], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg5 = torch.rand(
            [256, 88, 1], dtype=torch.float16, device="cuda", requires_grad=True
        )

        out_eager = foo(arg0, arg1, arg2, arg3, arg4, arg5)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0, arg1, arg2, arg3, arg4, arg5)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #164428")
    def test_fuzzer_issue_164428_already_exists(self):
        torch.manual_seed(6804)

        def foo(arg0, arg1, arg2):
            var_node_4 = (
                arg0  # size=(7, 1, 32), stride=(1, 1, 0), dtype=float64, device=cuda
            )
            var_node_5 = torch.full((7, 1, 32), -1.195053522845565, dtype=torch.float64)
            var_node_3 = torch.div(var_node_4, var_node_5)
            var_node_2 = torch.flatten(var_node_3)
            var_node_8 = torch.full((2,), -0.8316502130341195, dtype=torch.float64)
            var_node_9 = arg1  # size=(2, 224), stride=(224, 1), dtype=float64
            var_node_7 = torch.matmul(
                var_node_8.to(torch.float64), var_node_9.to(torch.float64)
            )
            var_node_10 = arg2  # size=(224,), stride=(1,), dtype=float64
            var_node_6 = torch.sub(var_node_7, var_node_10)
            var_node_1 = torch.sub(var_node_2, var_node_6)
            output = var_node_1
            return output

        arg0 = torch.rand(
            [7, 1, 32], dtype=torch.float64, device="cuda", requires_grad=True
        )
        arg1 = torch.rand(
            [2, 224], dtype=torch.float64, device="cuda", requires_grad=True
        )
        arg2 = torch.rand([224], dtype=torch.float64, device="cuda", requires_grad=True)

        out_eager = foo(arg0, arg1, arg2)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0, arg1, arg2)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    # Note: pytest has no built-in skipUnless marker (that is unittest's API),
    # so a bare @pytest.mark.skipUnless would never actually skip; skipif is
    # the pytest equivalent.
    @pytest.mark.skipif(not HAS_CUDA_AND_TRITON, reason="Needs CUDA")
    def test_fuzzer_issue_164086(self):
        torch.manual_seed(0)

        def foo(arg0, arg1, arg2, arg3, arg4, arg5):
            t0 = arg0  # size=(42, 56), stride=(42, 1), dtype=int64, device=cuda
            t1 = torch.tanh(
                t0
            )  # size=(42, 56), stride=(42, 1), dtype=int64, device=cuda
            t2 = t1.clone()
            t2.zero_()  # size=(42, 56), stride=(42, 1), dtype=int64, device=cuda
            t3 = (
                arg1  # size=(50000, 128), stride=(50000, 1), dtype=float16, device=cuda
            )
            t4 = arg2  # size=(46, 128), stride=(46, 1), dtype=float16, device=cuda
            t5 = torch.nn.functional.linear(
                t3, t4
            )  # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda
            t6 = arg3  # size=(50000, 4, 46), stride=(184, 46, 1), dtype=float16, device=cuda
            t7 = t6.max(
                dim=1
            ).values  # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda
            t8 = arg4  # size=(25786, 46), stride=(46, 1), dtype=float16, device=cuda
            t9 = arg5  # size=(24214, 46), stride=(46, 1), dtype=float16, device=cuda
            t10 = torch.cat(
                [t8, t9], dim=0
            )  # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda
            t11 = torch.pow(
                torch.pow(torch.pow(torch.pow(t5, t7), t10), t5), t7
            )  # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda
            t12 = torch.nn.functional.embedding(
                torch.clamp(t2, 0, t11.size(0) - 1).to(torch.long), t11
            )  # size=(42, 56, 46), stride=(2576, 46, 1), dtype=float16, device=cuda
            output = t12
            return output

        arg0 = torch.randint(0, 1000, [42, 56], dtype=torch.int64, device="cuda")
        arg1 = torch.rand(
            [50000, 128], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg2 = torch.rand(
            [46, 128], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg3 = torch.rand(
            [50000, 4, 46], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg4 = torch.rand(
            [25786, 46], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg5 = torch.rand(
            [24214, 46], dtype=torch.float16, device="cuda", requires_grad=True
        )

        out_eager = foo(arg0, arg1, arg2, arg3, arg4, arg5)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0, arg1, arg2, arg3, arg4, arg5)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #163877")
    def test_fuzzer_issue_163877(self):
        torch.manual_seed(0)

        def foo(arg0, arg1):
            t0 = arg0  # size=(401120, 3), stride=(3, 1), dtype=float32, device=cuda
            t1 = t0.clone()
            t1.zero_()  # size=(401120, 3), stride=(3, 1), dtype=float32, device=cuda
            t2 = t1.reshape(
                (109, 115, 96)
            )  # size=(109, 115, 96), stride=(11040, 96, 1), dtype=float32, device=cuda
            t3 = arg1  # size=(), stride=(), dtype=float32, device=cuda
            t4 = t3.contiguous()  # size=(), stride=(), dtype=float32, device=cuda
            t5 = torch.nn.functional.relu(
                t4
            )  # size=(), stride=(), dtype=float32, device=cuda
            t6 = t2.clone()
            t6.fill_(
                t5.item()
            )  # size=(109, 115, 96), stride=(11040, 96, 1), dtype=float32, device=cuda
            output = t6
            return output

        arg0 = torch.rand(
            [401120, 3], dtype=torch.float32, device="cuda", requires_grad=True
        )
        arg1 = torch.rand([], dtype=torch.float32, device="cuda", requires_grad=True)

        out_eager = foo(arg0, arg1)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0, arg1)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #163971")
    def test_fuzzer_issue_163971(self):
        torch.manual_seed(0)

        def foo(arg0):
            t0 = arg0  # size=(), stride=(), dtype=bfloat16, device=cuda
            t1 = torch.softmax(
                t0, dim=0
            )  # size=(), stride=(), dtype=bfloat16, device=cuda
            t2 = torch.nn.functional.gelu(
                t1
            )  # size=(), stride=(), dtype=bfloat16, device=cuda
            t3 = torch.softmax(
                t2, dim=0
            )  # size=(), stride=(), dtype=bfloat16, device=cuda
            output = t3
            return output

        arg0 = torch.rand([], dtype=torch.bfloat16, device="cuda", requires_grad=True)

        out_eager = foo(arg0)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #164059")
    def test_fuzzer_issue_164059(self):
        torch.manual_seed(0)

        def foo(arg0, arg1, arg2):
            t0 = arg0  # size=(16, 38073, 1), stride=(38073, 1, 1), dtype=float32, device=cuda
            t1 = t0.clone()
            t1.zero_()  # size=(16, 38073, 1), stride=(38073, 1, 1), dtype=float32, device=cuda
            t2 = t1.contiguous().view(
                (49, 112, 111)
            )  # size=(49, 112, 111), stride=(5488, 112, 1), dtype=float32, device=cuda
            t3 = arg1  # size=(1,), stride=(1,), dtype=int64, device=cuda
            t4 = arg2  # size=(1,), stride=(1,), dtype=int64, device=cuda
            t5 = t3 + t3 + t4  # size=(1,), stride=(1,), dtype=int64, device=cuda
            t6 = torch.exp(t5)  # size=(1,), stride=(1,), dtype=int64, device=cuda
            t7 = torch.nn.functional.layer_norm(
                t2, (111,)
            )  # size=(49, 112, 111), stride=(12432, 111, 1), dtype=float32, device=cuda
            output = t7
            return output

        arg0 = torch.rand(
            [16, 38073, 1], dtype=torch.float32, device="cuda", requires_grad=True
        )
        arg1 = torch.randint(0, 1000, [1], dtype=torch.int64, device="cuda")
        arg2 = torch.randint(0, 1000, [1], dtype=torch.int64, device="cuda")

        out_eager = foo(arg0, arg1, arg2)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0, arg1, arg2)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #164088")
    def test_fuzzer_issue_164088(self):
        torch.manual_seed(0)

        def foo(arg0, arg1, arg2, arg3, arg4):
            t0 = arg0  # size=(23, 4), stride=(4, 1), dtype=bfloat16, device=cuda
            t1 = t0.clone()
            t1.zero_()  # size=(23, 4), stride=(4, 1), dtype=bfloat16, device=cuda
            t2 = t1.contiguous().view(
                (92,)
            )  # size=(92,), stride=(1,), dtype=bfloat16, device=cuda
            t3 = arg1  # size=(5, 4, 5), stride=(20, 5, 1), dtype=bfloat16, device=cuda
            t4 = t3.min()  # size=(), stride=(), dtype=bfloat16, device=cuda
            t5 = arg2  # size=(), stride=(), dtype=bfloat16, device=cuda
            t6 = torch.nn.functional.silu(
                t5
            )  # size=(), stride=(), dtype=bfloat16, device=cuda
            t7 = arg3  # size=(3, 2, 3), stride=(6, 3, 1), dtype=bfloat16, device=cuda
            t8 = t7.min()  # size=(), stride=(), dtype=bfloat16, device=cuda
            t9 = arg4  # size=(), stride=(), dtype=bfloat16, device=cuda
            t10 = ((t8) / t9) / t9  # size=(), stride=(), dtype=bfloat16, device=cuda
            t11 = (
                t4 + t4 + t6 + t10 + t8
            )  # size=(), stride=(), dtype=bfloat16, device=cuda
            t12 = t2.clone()
            t12.fill_(
                t11.item()
            )  # size=(92,), stride=(1,), dtype=bfloat16, device=cuda
            output = t12
            return output

        arg0 = torch.rand(
            [23, 4], dtype=torch.bfloat16, device="cuda", requires_grad=True
        )
        arg1 = torch.rand(
            [5, 4, 5], dtype=torch.bfloat16, device="cuda", requires_grad=True
        )
        arg2 = torch.rand([], dtype=torch.bfloat16, device="cuda", requires_grad=True)
        arg3 = torch.rand(
            [3, 2, 3], dtype=torch.bfloat16, device="cuda", requires_grad=True
        )
        arg4 = torch.rand([], dtype=torch.bfloat16, device="cuda", requires_grad=True)

        out_eager = foo(arg0, arg1, arg2, arg3, arg4)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0, arg1, arg2, arg3, arg4)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #163894")
    def test_fuzzer_issue_163894(self):
        torch.manual_seed(9)

        def foo(arg0):
            var_node_1 = arg0  # size=(1, 2), stride=(2, 1), dtype=int64, device=cuda
            var_node_5 = torch.full(
                (1, 2), -66, dtype=torch.int32
            )  # size=(1, 2), stride=(2, 1), dtype=int32, device=cuda
            var_node_6 = torch.full(
                (1, 2), 77, dtype=torch.int64
            )  # size=(1, 2), stride=(2, 1), dtype=int64, device=cuda
            var_node_4 = torch.ops.aten.add(
                var_node_5, var_node_6
            )  # size=(1, 2), stride=(2, 1), dtype=int32, device=cuda
            var_node_7 = torch.full(
                (1, 2), -64, dtype=torch.int32
            )  # size=(1, 2), stride=(2, 1), dtype=int32, device=cuda
            var_node_3 = torch.ops.aten.mul(
                var_node_4, var_node_7
            )  # size=(1, 2), stride=(2, 1), dtype=int32, device=cuda
            var_node_9 = torch.full(
                (3, 4), False, dtype=torch.bool
            )  # size=(3, 4), stride=(4, 1), dtype=bool, device=cuda
            var_node_8 = torch.nonzero(
                var_node_9
            )  # size=(0, 2), stride=(2, 1), dtype=int64, device=cuda
            if var_node_8.numel() == 0:
                var_node_8 = torch.zeros((1, 2), dtype=torch.int64, device="cuda")
            var_node_2 = torch.ops.aten.add(var_node_3, var_node_8)
            output = var_node_2.float()
            return output

        arg0 = torch.randint(0, 10, [1, 2], dtype=torch.int64, device="cuda")

        out_eager = foo(arg0)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #164486")
    def test_fuzzer_issue_164486(self):
        torch.manual_seed(238)

        def foo(arg0):
            var_node_2 = torch.full(
                (), 1, dtype=torch.int16
            )  # size=(), stride=(), dtype=int16, device=cuda
            var_node_3 = arg0  # size=(), stride=(), dtype=int16, device=cuda
            var_node_1 = torch.add(
                var_node_2, var_node_3
            )  # size=(), stride=(), dtype=int16, device=cuda
            var_node_5 = torch.full(
                (1,), 3, dtype=torch.int16
            )  # size=(1,), stride=(1,), dtype=int16, device=cuda
            var_node_4 = torch.squeeze(
                var_node_5
            )  # size=(), stride=(), dtype=int16, device=cuda
            var_node_0 = torch.div(
                var_node_1, var_node_4
            )  # size=(), stride=(), dtype=int16, device=cuda
            result = var_node_0.float()
            return result

        arg0 = torch.randint(0, 10, [], dtype=torch.int16, device="cuda")

        out_eager = foo(arg0)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #163674")
    def test_fuzzer_issue_163674(self):
        torch.manual_seed(0)

        def foo(arg0, arg1, arg2):
            t0 = arg0  # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float16, device=cuda
            t1 = t0.clone()
            t1.zero_()  # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float16, device=cuda
            t2 = arg1  # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float32, device=cuda
            t3 = arg2  # size=(), stride=(), dtype=float32, device=cuda
            t4 = t2.clone()
            t4.fill_(
                t3.item()
            )  # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float32, device=cuda
            t5 = torch.pow(
                t1, t4
            )  # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float32, device=cuda
            t6 = t5.reshape(
                (96, 69, 36)
            )  # size=(96, 69, 36), stride=(2484, 36, 1), dtype=float32, device=cuda
            output = t6
            return output

        arg0 = torch.rand(
            [79488, 1, 3, 1], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg1 = torch.rand(
            [79488, 1, 3, 1], dtype=torch.float32, device="cuda", requires_grad=True
        )
        arg2 = torch.rand([], dtype=torch.float32, device="cuda", requires_grad=True)

        out_eager = foo(arg0, arg1, arg2)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0, arg1, arg2)
        out_compiled.sum().backward()
        print("Compile Success! ✅")


if __name__ == "__main__":
    run_tests()