[FlexAttention] fixing learnable bias assertion error in inductor (#161170)

Users encountered unexpected behaviour when using FlexAttention with learnable biases, including assertion errors (#157677)

We traced the root cause to the registration of subgraph buffers—this caused inconsistencies in the naming and ultimately incorrect retrieval later on. This problem only arose if the model was compiled as a whole (i.e., using @torch.compile), since only then would there be naming conflicts.

In this PR, we register the buffers with the base graph to solve this issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161170
Approved by: https://github.com/drisspg
This commit is contained in:
Angel Li
2025-08-23 06:24:19 +00:00
committed by PyTorch MergeBot
parent 6443ea337d
commit 3a4140bf8e
3 changed files with 51 additions and 8 deletions

View File

@ -5997,6 +5997,56 @@ class TestLearnableBiases(InductorTestCase):
],
)
@skip_on_cpu
@common_utils.parametrize(
    "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
@torch.compile
def test_learnable_bias_global_compiled(self, device, params):
    """Regression test for #157677 / #161170: when the whole model is compiled
    (``@torch.compile``), flex_attention with a learnable additive bias must
    still register subgraph buffers against the base graph so that the bias
    receives a (nonzero) gradient.

    NOTE(review): ``params`` is supplied by the parametrize decorator but is
    not consumed in the body — presumably kept for name_fn-driven test naming;
    confirm whether dtypes should be exercised here.
    """
    batch_size = 1
    num_heads = 1
    seq_len = 128
    head_dim = 16
    d_model = num_heads * head_dim

    query = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
    out_proj = nn.Linear(d_model, d_model, device=device)

    query.requires_grad = True
    key.requires_grad = True
    value.requires_grad = True

    # Learnable per-position additive bias, indexed by (batch, head, q, kv).
    bias = torch.randn(
        batch_size,
        num_heads,
        seq_len,
        seq_len,
        device=device,
        requires_grad=True,
    )

    def bias_mod(score, b, h, q_idx, kv_idx):
        return score + bias[b, h, q_idx, kv_idx]

    out = flex_attention(
        query=query,
        key=key,
        value=value,
        score_mod=bias_mod,
    )
    # (B, H, S, D) -> (B, S, H*D) for the output projection.
    out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
    attn_output = out_proj(out)

    random_target = torch.randn(batch_size, seq_len, d_model, device=device)
    loss = torch.nn.functional.mse_loss(attn_output, random_target)
    loss.backward()

    # BUG FIX: the original `assert bias.grad, ...` evaluates the truthiness of
    # a multi-element tensor, which raises "RuntimeError: Boolean value of
    # Tensor with more than one element is ambiguous" instead of performing the
    # intended presence check. Test for None explicitly.
    assert bias.grad is not None, "No gradient computed for bias"
    assert torch.any(bias.grad != 0), "Gradient for bias is 0"
@skip_on_cpu
@common_utils.parametrize(
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"

View File

@ -125,12 +125,6 @@ def build_subgraph_module_buffer(
with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type]
pw_subgraph.run(*args)
# Since we are allowing mutations/buffer creation, we need to register any fresh buffers
# created during the pointwise subgraph lowering
if len(pw_subgraph.buffers) > 0:
for buffer in pw_subgraph.buffers:
V.graph.register_buffer(buffer)
def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]:
if output_buffer is None:
return None

View File

@ -87,8 +87,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str:
if self._approved_mutator():
name = self.qualify_name(f"buf{len(self.buffers)}")
self.buffers.append(buffer)
name = self.root_graph.register_buffer(buffer, set_name=set_name)
return name
else:
raise SubgraphLoweringException(