Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[FlexAttention] fixing learnable bias assertion error in inductor (#161170)
Users encountered unexpected behaviour when using FlexAttention with learnable biases, including assertion errors (#157677). We traced the root cause to the registration of subgraph buffers: naming them from the subgraph's own local counter caused inconsistencies in buffer names and, ultimately, incorrect retrieval later on. The problem only arose when the model was compiled as a whole (i.e. with @torch.compile), since only then would naming conflicts occur. In this PR, we register the buffers with the base graph instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161170
Approved by: https://github.com/drisspg
Committed by PyTorch MergeBot
Parent: 6443ea337d
Commit: 3a4140bf8e
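For context, here is a minimal standalone sketch of the usage pattern described in the PR text: a learnable bias read inside a score_mod, with the surrounding call compiled via @torch.compile. This is illustrative only, not the reproducer from #157677; the shapes, variable names, and the CUDA device are assumptions.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

device = "cuda"  # assumption: a CUDA device, matching the @skip_on_cpu test below
B, H, S, D = 1, 1, 128, 16  # illustrative shapes, mirroring the new test

# Learnable additive bias captured by the score_mod closure.
bias = torch.randn(B, H, S, S, device=device, requires_grad=True)

def bias_mod(score, b, h, q_idx, kv_idx):
    return score + bias[b, h, q_idx, kv_idx]

# Compiling the whole call (rather than only flex_attention) is what exposed the
# subgraph-buffer naming issue described above.
@torch.compile
def attend(q, k, v):
    return flex_attention(q, k, v, score_mod=bias_mod)

q, k, v = (torch.randn(B, H, S, D, device=device, requires_grad=True) for _ in range(3))
attend(q, k, v).sum().backward()
print(bias.grad is not None)  # with the fix, the learnable bias receives a gradient
```

The test added in the first hunk below exercises the same pattern inside the test suite by applying @torch.compile directly to the test method.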
@@ -5997,6 +5997,56 @@ class TestLearnableBiases(InductorTestCase):
         ],
     )
 
+    @skip_on_cpu
+    @common_utils.parametrize(
+        "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
+    )
+    @torch.compile
+    def test_learnable_bias_global_compiled(self, device, params):
+        batch_size = 1
+        num_heads = 1
+        seq_len = 128
+        head_dim = 16
+        d_model = num_heads * head_dim
+
+        query = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
+        key = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
+        value = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
+
+        out_proj = nn.Linear(d_model, d_model, device=device)
+
+        query.requires_grad = True
+        key.requires_grad = True
+        value.requires_grad = True
+
+        bias = torch.randn(
+            batch_size,
+            num_heads,
+            seq_len,
+            seq_len,
+            device=device,
+            requires_grad=True,
+        )
+
+        def bias_mod(score, b, h, q_idx, kv_idx):
+            return score + bias[b, h, q_idx, kv_idx]
+
+        out = flex_attention(
+            query=query,
+            key=key,
+            value=value,
+            score_mod=bias_mod,
+        )
+        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
+
+        attn_output = out_proj(out)
+        random_target = torch.randn(batch_size, seq_len, d_model, device=device)
+        loss = torch.nn.functional.mse_loss(attn_output, random_target)
+        loss.backward()
+
+        assert bias.grad is not None, "No gradient computed for bias"
+        assert torch.any(bias.grad != 0), "Gradient for bias is 0"
+
     @skip_on_cpu
     @common_utils.parametrize(
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
@@ -125,12 +125,6 @@ def build_subgraph_module_buffer(
     with V.set_graph_handler(pw_subgraph):  # type: ignore[arg-type]
         pw_subgraph.run(*args)
 
-    # Since we are allowing mutations/buffer creation, we need to register any fresh buffers
-    # creating during the pointwise subgraph lowering
-    if len(pw_subgraph.buffers) > 0:
-        for buffer in pw_subgraph.buffers:
-            V.graph.register_buffer(buffer)
-
     def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]:
         if output_buffer is None:
             return None
@@ -87,8 +87,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
 
     def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str:
         if self._approved_mutator():
-            name = self.qualify_name(f"buf{len(self.buffers)}")
-            self.buffers.append(buffer)
+            name = self.root_graph.register_buffer(buffer, set_name=set_name)
             return name
         else:
             raise SubgraphLoweringException(
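To illustrate the design choice behind the last two hunks, here is a simplified toy sketch; these classes are illustrative only and are not Inductor's real GraphLowering or PointwiseSubgraphLowering. When each subgraph names buffers from its own local counter, two subgraphs in one compiled model can both hand out "buf0", and a later name-based lookup in the base graph can retrieve the wrong buffer; delegating registration to the root graph makes the counter, and therefore the names, global.

```python
# Toy model of the two registration strategies; not Inductor's actual classes.
class RootGraph:
    def __init__(self):
        self.buffers = {}

    def register_buffer(self, buf):
        name = f"buf{len(self.buffers)}"  # single global counter -> unique names
        self.buffers[name] = buf
        return name


class Subgraph:
    def __init__(self, root):
        self.root = root
        self.local = []

    def register_buffer_old(self, buf):
        # Pre-fix behaviour: each subgraph counts its own buffers, so two
        # subgraphs can both return "buf0" and a later name-based retrieval
        # from the base graph picks up the wrong buffer.
        name = f"buf{len(self.local)}"
        self.local.append(buf)
        return name

    def register_buffer_new(self, buf):
        # Post-fix behaviour: defer naming and registration to the base graph.
        return self.root.register_buffer(buf)


root = RootGraph()
a, b = Subgraph(root), Subgraph(root)
print(a.register_buffer_old("x"), b.register_buffer_old("y"))  # buf0 buf0 (collision)
print(a.register_buffer_new("x"), b.register_buffer_new("y"))  # buf0 buf1 (unique)
```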