Patch flex_attention._get_mod_type to not use inspect.signature when computing num_positional_args (an alternative fix for the flex attention graph break on create_block_mask) (#164923)
The initial fix for the `inspect.signature` graph break did not take the right approach (https://github.com/pytorch/pytorch/pull/164349#pullrequestreview-3306614010). As @williamwen42 suggested (https://github.com/pytorch/pytorch/pull/164349#issuecomment-3379222885), for now we can simply get rid of the `inspect.signature` call in flex_attention to resolve this high-priority issue (https://github.com/pytorch/pytorch/issues/164247#issuecomment-3378673179). This PR does exactly that: it limits the scope of the fix to computing `num_positional_args` in `flex_attention._get_mod_type` from the properties returned by `NestedUserFunctionVariable.const_getattr` (some were missing, so they are added here). Fixes #164247. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164923 Approved by: https://github.com/williamwen42
committed by: PyTorch MergeBot
parent: da8517fa63
commit: cff1b20771
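
For context, the change below swaps an `inspect.signature` call (which Dynamo could not trace without a graph break) for a count computed from `fn.__code__` and `fn.__defaults__`. A minimal sketch of the equivalence for ordinary functions; the example functions here are hypothetical and only illustrate the counting:

import inspect


def score_mod(score, b, h, q_idx, kv_idx):  # 5 required positional args
    return score


def windowed_mask(b, h, q_idx, kv_idx, window=16):  # 4 required, 1 defaulted
    return (q_idx - kv_idx) <= window


for fn in (score_mod, windowed_mask):
    # Attribute-based count, the approach this PR switches to:
    by_code = fn.__code__.co_argcount - len(fn.__defaults__ or ())
    # Signature-based count, the approach this PR removes from the traced path:
    by_signature = sum(
        1
        for p in inspect.signature(fn).parameters.values()
        if p.default is inspect.Parameter.empty
    )
    assert by_code == by_signature  # 5 for score_mod, 4 for windowed_mask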
@@ -46,6 +46,7 @@ from torch._dynamo.backends.debugging import ExplainWithBackend
 from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.testing import (
     CompileCounter,
+    CompileCounterWithBackend,
     EagerAndRecordGraphs,
     rand_strided,
     same,
@@ -54,6 +55,7 @@ from torch._dynamo.testing import (
 )
 from torch._inductor.utils import fresh_cache
 from torch.nn import functional as F
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 from torch.profiler import profile, ProfilerActivity
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
@@ -7369,6 +7371,67 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
         )
         self.assertEqual(explain_output.break_reasons[0].reason, expected_msg)
 
+    @parametrize("backend", ["eager", "inductor"])
+    def test_issue164247(self, backend: str):
+        if backend == "inductor" and torch._dynamo.config.dynamic_shapes:
+            raise unittest.SkipTest(
+                "Skip only in dynamic-shapes wrapper (known issue #157612)"
+            )
+
+        class MixedFakeModeModel(nn.Module):
+            def __init__(self, dim=64):
+                super().__init__()
+                self.dim = dim
+                self.lin = torch.nn.Linear(64, 64)
+
+            def forward(self, x):
+                batch_size, seq_len, _ = x.shape
+
+                # Process input first - this creates fake tensors in export's fake mode
+                processed = self.lin(x)
+
+                # Create some computation that depends on processed tensor
+                intermediate = processed.sum(dim=-1).detach()  # Shape: (batch, seq_len)
+
+                def dynamic_mask_function(batch_idx, head_idx, q_idx, kv_idx):
+                    threshold = intermediate[
+                        batch_idx, q_idx % seq_len
+                    ]  # Access the captured tensor
+                    return (kv_idx <= q_idx) & (threshold > 0)
+
+                block_mask = create_block_mask(
+                    mask_mod=dynamic_mask_function,
+                    B=batch_size,
+                    H=None,
+                    Q_LEN=seq_len,
+                    KV_LEN=seq_len,
+                    device=x.device,
+                    _compile=False,
+                )
+                q = processed.view(batch_size, 1, seq_len, self.dim)
+                k = processed.view(batch_size, 1, seq_len, self.dim)
+                v = processed.view(batch_size, 1, seq_len, self.dim)
+
+                out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
+                out = flex_attention(q, k, v, block_mask=block_mask)
+
+                return out
+
+        backend_counter = CompileCounterWithBackend(backend)
+        model = MixedFakeModeModel()
+        compiled = torch.compile(model, backend=backend_counter, fullgraph=True)
+
+        if backend == "inductor":
+            # A known InductorError Issue https://github.com/pytorch/pytorch/issues/157612
+            with self.assertRaises(RuntimeError):
+                compiled(torch.randn(2, 128, 64))
+        else:
+            compiled(torch.randn(2, 128, 64))
+
+        # One graph, so no graph breaks
+        self.assertEqual(backend_counter.frame_count, 1)
+        self.assertEqual(len(backend_counter.graphs), 1)
+
     # https://github.com/pytorch/pytorch/issues/164990
     def test_guard_same_frame_fail_message(self):
         import torch._dynamo.guards as g
@@ -1320,9 +1320,21 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
 
     def const_getattr(self, tx, name):
         if name == "__name__":
-            return self.fn_name.as_python_constant()
+            return self.get_name()
+        if name == "__code__":
+            return self.get_code()
+        if name == "__defaults__":
+            d = getattr(self, "defaults", None)
+            return d.as_python_constant() if d else None
         return super().const_getattr(tx, name)
+
+    def call_obj_hasattr(self, tx: "InstructionTranslator", name):
+        if name == "__code__":
+            return variables.ConstantVariable.create(hasattr(self, "code"))
+        if name == "__defaults__":
+            return variables.ConstantVariable.create(hasattr(self, "defaults"))
+        return super().call_obj_hasattr(tx, name)
 
     def has_self(self):
         return False
 
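
These new branches mirror what CPython itself exposes on a nested function, so attribute reads that Dynamo traces (such as the ones in `_get_mod_type` below) resolve to constants instead of graph-breaking. A quick plain-Python illustration of the contract, with no Dynamo involved and illustrative names only:

def outer():
    def mask_mod(b, h, q_idx, kv_idx):
        return kv_idx <= q_idx

    return mask_mod


fn = outer()

# Plain-CPython behavior the traced variable now reproduces as constants:
assert fn.__name__ == "mask_mod"      # served via get_name()
assert hasattr(fn, "__code__")        # served via call_obj_hasattr
assert fn.__code__.co_argcount == 4   # served via get_code()
assert fn.__defaults__ is None        # no defaults -> None, as in const_getattr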
@@ -267,11 +267,20 @@ def _get_mod_type(fn: Callable) -> _ModificationType:
     considered as a score_mod function. If the function has 4 positional arguments, it is
     considered as a mask function.
     """
-    num_positional_args = sum(
-        1
-        for param in inspect.signature(fn).parameters.values()
-        if param.default is inspect.Parameter.empty
-    )
+    if hasattr(fn, "__code__"):
+        code = fn.__code__
+        num_positional_total = code.co_argcount
+        defaults = ()
+        if hasattr(fn, "__defaults__"):
+            defaults = fn.__defaults__ or ()
+        num_defaults = len(defaults)
+        num_positional_args = num_positional_total - num_defaults
+    else:
+        num_positional_args = sum(
+            1
+            for param in inspect.signature(fn).parameters.values()
+            if param.default is inspect.Parameter.empty
+        )
     assert num_positional_args == 5 or num_positional_args == 4
     if num_positional_args == 5:
         return _ModificationType.SCORE_MOD
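
Taken together, `_get_mod_type` now classifies a callable purely by its required positional arity: five arguments means a score_mod, four means a mask function. A hedged standalone sketch of the expected behavior; `ModificationType` and `get_mod_type` below are stand-ins reimplementing the logic from the diff above, not the private PyTorch API:

import functools
import inspect
from enum import Enum


class ModificationType(Enum):  # stand-in for torch's private _ModificationType
    SCORE_MOD = 1
    MASK_MOD = 2


def get_mod_type(fn):
    # Same counting logic as the patched _get_mod_type above.
    if hasattr(fn, "__code__"):
        num_positional_args = fn.__code__.co_argcount - len(fn.__defaults__ or ())
    else:
        # Callables without __code__ (e.g. functools.partial) keep the old path.
        num_positional_args = sum(
            1
            for p in inspect.signature(fn).parameters.values()
            if p.default is inspect.Parameter.empty
        )
    assert num_positional_args in (4, 5)
    if num_positional_args == 5:
        return ModificationType.SCORE_MOD
    return ModificationType.MASK_MOD


def relative_bias(score, b, h, q_idx, kv_idx):  # score_mod: 5 args
    return score + (q_idx - kv_idx)


def causal(b, h, q_idx, kv_idx):  # mask function: 4 args
    return kv_idx <= q_idx


assert get_mod_type(relative_bias) is ModificationType.SCORE_MOD
assert get_mod_type(causal) is ModificationType.MASK_MOD
# A partial has no __code__, so the inspect.signature fallback still runs:
assert get_mod_type(functools.partial(relative_bias)) is ModificationType.SCORE_MOD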