From cff1b207717b84b6ac3fdc95fc5ac91cc3802b63 Mon Sep 17 00:00:00 2001
From: jmaczan
Date: Fri, 17 Oct 2025 17:44:43 +0000
Subject: [PATCH] Patch flex_attention._get_mod_type to not use
 inspect.signature when computing num_positional_args (an alternative fix for
 flex attention graph break on create_block_mask) (#164923)

The initial fix for `inspect.signature` was not the right approach (https://github.com/pytorch/pytorch/pull/164349#pullrequestreview-3306614010). As @williamwen42 suggests (https://github.com/pytorch/pytorch/pull/164349#issuecomment-3379222885), for now we can simply get rid of the `inspect.signature` call in flex_attention to resolve this high-priority issue (https://github.com/pytorch/pytorch/issues/164247#issuecomment-3378673179).

In this PR I did exactly that: I limited the scope of the fix to computing `num_positional_args` in `flex_attention._get_mod_type` based on the properties returned by `NestedUserFunctionVariable.const_getattr` (some were missing, so I added them).

Fixes #164247

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164923
Approved by: https://github.com/williamwen42
---
 test/dynamo/test_repros.py                    | 63 +++++++++++++++++++
 .../TestScript.test_python_frontend           |  0
 .../TestScript.test_python_frontend_py3       |  0
 torch/_dynamo/variables/functions.py          | 14 ++++-
 torch/nn/attention/flex_attention.py          | 19 ++++--
 5 files changed, 90 insertions(+), 6 deletions(-)
 delete mode 100644 test/dynamo_expected_failures/TestScript.test_python_frontend
 delete mode 100644 test/dynamo_expected_failures/TestScript.test_python_frontend_py3

diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index db950037a194..47692a4fa81b 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -46,6 +46,7 @@ from torch._dynamo.backends.debugging import ExplainWithBackend
 from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.testing import (
     CompileCounter,
+    CompileCounterWithBackend,
     EagerAndRecordGraphs,
     rand_strided,
     same,
@@ -54,6 +55,7 @@ from torch._dynamo.testing import (
 )
 from torch._inductor.utils import fresh_cache
 from torch.nn import functional as F
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 from torch.profiler import profile, ProfilerActivity
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
@@ -7369,6 +7371,67 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
         )
         self.assertEqual(explain_output.break_reasons[0].reason, expected_msg)
 
+    @parametrize("backend", ["eager", "inductor"])
+    def test_issue164247(self, backend: str):
+        if backend == "inductor" and torch._dynamo.config.dynamic_shapes:
+            raise unittest.SkipTest(
+                "Skip only in dynamic-shapes wrapper (known issue #157612)"
+            )
+
+        class MixedFakeModeModel(nn.Module):
+            def __init__(self, dim=64):
+                super().__init__()
+                self.dim = dim
+                self.lin = torch.nn.Linear(64, 64)
+
+            def forward(self, x):
+                batch_size, seq_len, _ = x.shape
+
+                # Process input first - this creates fake tensors in export's fake mode
+                processed = self.lin(x)
+
+                # Create some computation that depends on processed tensor
+                intermediate = processed.sum(dim=-1).detach()  # Shape: (batch, seq_len)
+
+                def dynamic_mask_function(batch_idx, head_idx, q_idx, kv_idx):
+                    threshold = intermediate[
+                        batch_idx, q_idx % seq_len
+                    ]  # Access the captured tensor
+                    return (kv_idx <= q_idx) & (threshold > 0)
+
+                block_mask = create_block_mask(
+                    mask_mod=dynamic_mask_function,
+                    B=batch_size,
+                    H=None,
+                    Q_LEN=seq_len,
+                    KV_LEN=seq_len,
+                    device=x.device,
+                    _compile=False,
+                )
+                q = processed.view(batch_size, 1, seq_len, self.dim)
+                k = processed.view(batch_size, 1, seq_len, self.dim)
+                v = processed.view(batch_size, 1, seq_len, self.dim)
+
+                out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
+                out = flex_attention(q, k, v, block_mask=block_mask)
+
+                return out
+
+        backend_counter = CompileCounterWithBackend(backend)
+        model = MixedFakeModeModel()
+        compiled = torch.compile(model, backend=backend_counter, fullgraph=True)
+
+        if backend == "inductor":
+            # A known InductorError Issue https://github.com/pytorch/pytorch/issues/157612
+            with self.assertRaises(RuntimeError):
+                compiled(torch.randn(2, 128, 64))
+        else:
+            compiled(torch.randn(2, 128, 64))
+
+        # One graph, so no graph breaks
+        self.assertEqual(backend_counter.frame_count, 1)
+        self.assertEqual(len(backend_counter.graphs), 1)
+
     # https://github.com/pytorch/pytorch/issues/164990
     def test_guard_same_frame_fail_message(self):
         import torch._dynamo.guards as g
diff --git a/test/dynamo_expected_failures/TestScript.test_python_frontend b/test/dynamo_expected_failures/TestScript.test_python_frontend
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/test/dynamo_expected_failures/TestScript.test_python_frontend_py3 b/test/dynamo_expected_failures/TestScript.test_python_frontend_py3
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py
index 7d534de073c9..4911ded6e333 100644
--- a/torch/_dynamo/variables/functions.py
+++ b/torch/_dynamo/variables/functions.py
@@ -1320,9 +1320,21 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
 
     def const_getattr(self, tx, name):
         if name == "__name__":
-            return self.fn_name.as_python_constant()
+            return self.get_name()
+        if name == "__code__":
+            return self.get_code()
+        if name == "__defaults__":
+            d = getattr(self, "defaults", None)
+            return d.as_python_constant() if d else None
         return super().const_getattr(tx, name)
 
+    def call_obj_hasattr(self, tx: "InstructionTranslator", name):
+        if name == "__code__":
+            return variables.ConstantVariable.create(hasattr(self, "code"))
+        if name == "__defaults__":
+            return variables.ConstantVariable.create(hasattr(self, "defaults"))
+        return super().call_obj_hasattr(tx, name)
+
     def has_self(self):
         return False
 
diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py
index a608020f30f3..0a4acdd7a232 100644
--- a/torch/nn/attention/flex_attention.py
+++ b/torch/nn/attention/flex_attention.py
@@ -267,11 +267,20 @@ def _get_mod_type(fn: Callable) -> _ModificationType:
     considered as a score_mod function. If the function has 4 positional arguments, it is
     considered as a mask function.
     """
-    num_positional_args = sum(
-        1
-        for param in inspect.signature(fn).parameters.values()
-        if param.default is inspect.Parameter.empty
-    )
+    if hasattr(fn, "__code__"):
+        code = fn.__code__
+        num_positional_total = code.co_argcount
+        defaults = ()
+        if hasattr(fn, "__defaults__"):
+            defaults = fn.__defaults__ or ()
+        num_defaults = len(defaults)
+        num_positional_args = num_positional_total - num_defaults
+    else:
+        num_positional_args = sum(
+            1
+            for param in inspect.signature(fn).parameters.values()
+            if param.default is inspect.Parameter.empty
+        )
     assert num_positional_args == 5 or num_positional_args == 4
     if num_positional_args == 5:
         return _ModificationType.SCORE_MOD
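
For reference, the core of the change to `_get_mod_type` can be illustrated with a minimal standalone sketch (not part of the patch; `count_positional_args` is a hypothetical helper used only for illustration): for plain Python functions, including nested closures such as `dynamic_mask_function` in the test above, `__code__.co_argcount` minus `len(__defaults__)` gives the same positional-argument count that `inspect.signature` was used for, while only touching attributes that `NestedUserFunctionVariable.const_getattr` can now answer as constants without a graph break.

import inspect


def count_positional_args(fn):
    # Hypothetical helper mirroring the patched logic: prefer __code__ /
    # __defaults__, which Dynamo can resolve as constants, and fall back to
    # inspect.signature only for callables without a __code__ attribute.
    if hasattr(fn, "__code__"):
        num_defaults = len(fn.__defaults__ or ()) if hasattr(fn, "__defaults__") else 0
        return fn.__code__.co_argcount - num_defaults
    return sum(
        1
        for param in inspect.signature(fn).parameters.values()
        if param.default is inspect.Parameter.empty
    )


def mask_mod(batch_idx, head_idx, q_idx, kv_idx):  # 4 positional args -> mask function
    return kv_idx <= q_idx


def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):  # 5 positional args -> score_mod
    return score


assert count_positional_args(mask_mod) == 4
assert count_positional_args(score_mod) == 5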