Patch the flex_attention._get_mod_type to not use inspect.signature when computing num_positional_args (an alternative fix for flex attention graph break on create_block_mask) (#164923)

The initial fix for inspect.signature uses not a right approach (https://github.com/pytorch/pytorch/pull/164349#pullrequestreview-3306614010). As @williamwen42 suggests (https://github.com/pytorch/pytorch/pull/164349#issuecomment-3379222885) we can just for now get rid of `inspect.signature` call in flex_attention to resolve this high priority issue (https://github.com/pytorch/pytorch/issues/164247#issuecomment-3378673179). In this PR I did exactly this - limited the scope of fix to just computing `num_positional_args` in `flex_attention._get_mod_type` based on properties returned by `NestedUserFunctionVariable.const_getattr` (some were missing so I added them)

Fixes #164247

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164923
Approved by: https://github.com/williamwen42
This commit is contained in:
jmaczan
2025-10-17 17:44:43 +00:00
committed by PyTorch MergeBot
parent da8517fa63
commit cff1b20771
5 changed files with 90 additions and 6 deletions

View File

@ -46,6 +46,7 @@ from torch._dynamo.backends.debugging import ExplainWithBackend
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import (
CompileCounter,
CompileCounterWithBackend,
EagerAndRecordGraphs,
rand_strided,
same,
@ -54,6 +55,7 @@ from torch._dynamo.testing import (
)
from torch._inductor.utils import fresh_cache
from torch.nn import functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.profiler import profile, ProfilerActivity
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION,
@ -7369,6 +7371,67 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
)
self.assertEqual(explain_output.break_reasons[0].reason, expected_msg)
@parametrize("backend", ["eager", "inductor"])
def test_issue164247(self, backend: str):
if backend == "inductor" and torch._dynamo.config.dynamic_shapes:
raise unittest.SkipTest(
"Skip only in dynamic-shapes wrapper (known issue #157612)"
)
class MixedFakeModeModel(nn.Module):
def __init__(self, dim=64):
super().__init__()
self.dim = dim
self.lin = torch.nn.Linear(64, 64)
def forward(self, x):
batch_size, seq_len, _ = x.shape
# Process input first - this creates fake tensors in export's fake mode
processed = self.lin(x)
# Create some computation that depends on processed tensor
intermediate = processed.sum(dim=-1).detach() # Shape: (batch, seq_len)
def dynamic_mask_function(batch_idx, head_idx, q_idx, kv_idx):
threshold = intermediate[
batch_idx, q_idx % seq_len
] # Access the captured tensor
return (kv_idx <= q_idx) & (threshold > 0)
block_mask = create_block_mask(
mask_mod=dynamic_mask_function,
B=batch_size,
H=None,
Q_LEN=seq_len,
KV_LEN=seq_len,
device=x.device,
_compile=False,
)
q = processed.view(batch_size, 1, seq_len, self.dim)
k = processed.view(batch_size, 1, seq_len, self.dim)
v = processed.view(batch_size, 1, seq_len, self.dim)
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
out = flex_attention(q, k, v, block_mask=block_mask)
return out
backend_counter = CompileCounterWithBackend(backend)
model = MixedFakeModeModel()
compiled = torch.compile(model, backend=backend_counter, fullgraph=True)
if backend == "inductor":
# A known InductorError Issue https://github.com/pytorch/pytorch/issues/157612
with self.assertRaises(RuntimeError):
compiled(torch.randn(2, 128, 64))
else:
compiled(torch.randn(2, 128, 64))
# One graph, so no graph breaks
self.assertEqual(backend_counter.frame_count, 1)
self.assertEqual(len(backend_counter.graphs), 1)
# https://github.com/pytorch/pytorch/issues/164990
def test_guard_same_frame_fail_message(self):
import torch._dynamo.guards as g

View File

@ -1320,9 +1320,21 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
def const_getattr(self, tx, name):
if name == "__name__":
return self.fn_name.as_python_constant()
return self.get_name()
if name == "__code__":
return self.get_code()
if name == "__defaults__":
d = getattr(self, "defaults", None)
return d.as_python_constant() if d else None
return super().const_getattr(tx, name)
def call_obj_hasattr(self, tx: "InstructionTranslator", name):
if name == "__code__":
return variables.ConstantVariable.create(hasattr(self, "code"))
if name == "__defaults__":
return variables.ConstantVariable.create(hasattr(self, "defaults"))
return super().call_obj_hasattr(tx, name)
def has_self(self):
return False

View File

@ -267,6 +267,15 @@ def _get_mod_type(fn: Callable) -> _ModificationType:
considered as a score_mod function. If the function has 4 positional arguments, it is
considered as a mask function.
"""
if hasattr(fn, "__code__"):
code = fn.__code__
num_positional_total = code.co_argcount
defaults = ()
if hasattr(fn, "__defaults__"):
defaults = fn.__defaults__ or ()
num_defaults = len(defaults)
num_positional_args = num_positional_total - num_defaults
else:
num_positional_args = sum(
1
for param in inspect.signature(fn).parameters.values()