diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index db950037a194..47692a4fa81b 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -46,6 +46,7 @@ from torch._dynamo.backends.debugging import ExplainWithBackend from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import ( CompileCounter, + CompileCounterWithBackend, EagerAndRecordGraphs, rand_strided, same, @@ -54,6 +55,7 @@ from torch._dynamo.testing import ( ) from torch._inductor.utils import fresh_cache from torch.nn import functional as F +from torch.nn.attention.flex_attention import create_block_mask, flex_attention from torch.profiler import profile, ProfilerActivity from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, @@ -7369,6 +7371,67 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor): ) self.assertEqual(explain_output.break_reasons[0].reason, expected_msg) + @parametrize("backend", ["eager", "inductor"]) + def test_issue164247(self, backend: str): + if backend == "inductor" and torch._dynamo.config.dynamic_shapes: + raise unittest.SkipTest( + "Skip only in dynamic-shapes wrapper (known issue #157612)" + ) + + class MixedFakeModeModel(nn.Module): + def __init__(self, dim=64): + super().__init__() + self.dim = dim + self.lin = torch.nn.Linear(64, 64) + + def forward(self, x): + batch_size, seq_len, _ = x.shape + + # Process input first - this creates fake tensors in export's fake mode + processed = self.lin(x) + + # Create some computation that depends on processed tensor + intermediate = processed.sum(dim=-1).detach() # Shape: (batch, seq_len) + + def dynamic_mask_function(batch_idx, head_idx, q_idx, kv_idx): + threshold = intermediate[ + batch_idx, q_idx % seq_len + ] # Access the captured tensor + return (kv_idx <= q_idx) & (threshold > 0) + + block_mask = create_block_mask( + mask_mod=dynamic_mask_function, + B=batch_size, + H=None, + Q_LEN=seq_len, + KV_LEN=seq_len, + device=x.device, + _compile=False, + ) + q = processed.view(batch_size, 1, seq_len, self.dim) + k = processed.view(batch_size, 1, seq_len, self.dim) + v = processed.view(batch_size, 1, seq_len, self.dim) + + out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask) + out = flex_attention(q, k, v, block_mask=block_mask) + + return out + + backend_counter = CompileCounterWithBackend(backend) + model = MixedFakeModeModel() + compiled = torch.compile(model, backend=backend_counter, fullgraph=True) + + if backend == "inductor": + # A known InductorError Issue https://github.com/pytorch/pytorch/issues/157612 + with self.assertRaises(RuntimeError): + compiled(torch.randn(2, 128, 64)) + else: + compiled(torch.randn(2, 128, 64)) + + # One graph, so no graph breaks + self.assertEqual(backend_counter.frame_count, 1) + self.assertEqual(len(backend_counter.graphs), 1) + # https://github.com/pytorch/pytorch/issues/164990 def test_guard_same_frame_fail_message(self): import torch._dynamo.guards as g diff --git a/test/dynamo_expected_failures/TestScript.test_python_frontend b/test/dynamo_expected_failures/TestScript.test_python_frontend deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestScript.test_python_frontend_py3 b/test/dynamo_expected_failures/TestScript.test_python_frontend_py3 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 7d534de073c9..4911ded6e333 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1320,9 +1320,21 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable): def const_getattr(self, tx, name): if name == "__name__": - return self.fn_name.as_python_constant() + return self.get_name() + if name == "__code__": + return self.get_code() + if name == "__defaults__": + d = getattr(self, "defaults", None) + return d.as_python_constant() if d else None return super().const_getattr(tx, name) + def call_obj_hasattr(self, tx: "InstructionTranslator", name): + if name == "__code__": + return variables.ConstantVariable.create(hasattr(self, "code")) + if name == "__defaults__": + return variables.ConstantVariable.create(hasattr(self, "defaults")) + return super().call_obj_hasattr(tx, name) + def has_self(self): return False diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index a608020f30f3..0a4acdd7a232 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -267,11 +267,20 @@ def _get_mod_type(fn: Callable) -> _ModificationType: considered as a score_mod function. If the function has 4 positional arguments, it is considered as a mask function. """ - num_positional_args = sum( - 1 - for param in inspect.signature(fn).parameters.values() - if param.default is inspect.Parameter.empty - ) + if hasattr(fn, "__code__"): + code = fn.__code__ + num_positional_total = code.co_argcount + defaults = () + if hasattr(fn, "__defaults__"): + defaults = fn.__defaults__ or () + num_defaults = len(defaults) + num_positional_args = num_positional_total - num_defaults + else: + num_positional_args = sum( + 1 + for param in inspect.signature(fn).parameters.values() + if param.default is inspect.Parameter.empty + ) assert num_positional_args == 5 or num_positional_args == 4 if num_positional_args == 5: return _ModificationType.SCORE_MOD