[FlexAttention] Fix output layout (#135882)

We previously only supported the case where v_head_dim matched qk_head_dim. When we allowed differing head dims, I accidentally kept the query's strides for the output. This PR fixes that bug and also ensures that we always produce output in the same stride order as the input query.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135882
Approved by: https://github.com/yanboliang, https://github.com/Chillee
drisspg
2024-09-12 19:05:42 -07:00
committed by PyTorch MergeBot
parent ad2f0e9f81
commit ae02d663cd
3 changed files with 132 additions and 5 deletions
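For illustration (not part of this commit), a minimal sketch of the behavior the fix guarantees. It assumes a CUDA build; get_stride_order is an inductor-internal helper, used the same way in the new test below.

import torch
from torch.nn.attention.flex_attention import flex_attention
from torch._inductor.ir import get_stride_order  # internal helper, also used by the new test

# (B, H, S, D) tensors whose memory layout is not the default (0, 1, 2, 3) order.
make = lambda: torch.randn(2, 128, 4, 16, device="cuda").permute(0, 2, 1, 3)
query, key, value = make(), make(), make()

out = flex_attention(query, key, value)

# After this fix, the output follows the query's stride order instead of
# always being produced in the default contiguous layout.
assert get_stride_order(out.stride()) == get_stride_order(query.stride())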


@@ -1666,6 +1666,52 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         out = func(query, key, value, block_mask=block_mask)
         out.sum().backward()
 
+    @supported_platform
+    @common_utils.parametrize("mode", ["eager", "inductor"])
+    @common_utils.parametrize(
+        "permute_order",
+        [
+            (0, 1, 2, 3),  # Default order
+            (1, 0, 2, 3),  # Reverse order
+            (0, 2, 1, 3),  # Mixed order
+            (2, 0, 1, 3),  # Another mixed order
+        ],
+    )
+    @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)])
+    def test_flex_attention_stride_ordering(self, mode, permute_order, shape):
+        from torch._inductor.ir import get_stride_order
+
+        # Setup
+        make_tensor = functools.partial(
+            torch.randn,
+            shape,
+            device="cuda",
+            dtype=torch.float32,
+            requires_grad=True,
+        )
+
+        # Create and permute tensors
+        query, key, value = make_tensor(), make_tensor(), make_tensor()
+        query = query.permute(permute_order)
+        key = key.permute(permute_order)
+        value = value.permute(permute_order)
+
+        if mode == "inductor":
+            func = torch.compile(flex_attention, backend=mode, fullgraph=True)
+        else:
+            func = flex_attention
+
+        out = func(query, key, value)
+
+        out_stride_order = get_stride_order(out.stride())
+        query_stride_order = get_stride_order(query.stride())
+
+        self.assertEqual(
+            out_stride_order,
+            query_stride_order,
+            f"Stride order mismatch: out {out_stride_order}, query {query_stride_order}",
+        )
+
     @supported_platform
     @common_utils.parametrize("compile", [True, False])
     def test_fully_masked_out_rows_0_check(self, compile: bool):
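The assertion in the new test compares stride orders, not raw strides: get_stride_order ranks each dimension by its stride, with the smallest (innermost) stride getting rank 0, so two tensors compare equal exactly when their dimensions are laid out in the same relative order. A rough, illustration-only equivalent (the inductor implementation differs in details such as symbolic-shape handling):

# Illustration only: approximate behavior of torch._inductor.ir.get_stride_order.
def stride_order(strides):
    ascending = sorted(range(len(strides)), key=lambda d: strides[d])
    order = [0] * len(strides)
    for rank, dim in enumerate(ascending):
        order[dim] = rank
    return order

assert stride_order([8192, 2048, 16, 1]) == [3, 2, 1, 0]  # default contiguous layout
assert stride_order([8192, 16, 1024, 1]) == [3, 1, 2, 0]  # (0, 2, 1, 3)-permuted layout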


@@ -1,7 +1,7 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 import math
-from typing import Any, Callable, Dict, Tuple, Union
+from typing import Any, Callable, Dict, Sequence, Tuple, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -23,6 +23,53 @@ from torch.fx.graph_module import GraphModule
 from torch.overrides import TorchFunctionMode
 
 
+# Duplicate of _inductor/kernel/flex_attention.py to avoid circular import
+def _construct_strides(
+    sizes: Sequence[int],
+    fill_order: Sequence[int],
+) -> Sequence[int]:
+    """From a list of sizes and a fill order, construct the strides of the permuted tensor."""
+    # Initialize strides
+    assert len(sizes) == len(
+        fill_order
+    ), "Length of sizes must match the length of the fill order"
+    strides = [0] * len(sizes)
+
+    # Start with stride 1 for the innermost dimension
+    current_stride = 1
+
+    # Iterate through the fill order populating strides
+    for dim in fill_order:
+        strides[dim] = current_stride
+        current_stride *= sizes[dim]
+
+    return strides
+
+
+def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch.Tensor:
+    """
+    Create a new tensor with the same data and shape as the input,
+    but with strides permuted based on the input tensor's stride order.
+
+    Args:
+        out (torch.Tensor): The output tensor of attention.
+        query_strides (List[int]): The stride order of the input query tensor
+
+    Returns:
+        torch.Tensor: A new tensor with same shape and data as the input,
+        but with strides permuted based on the query tensor's stride order.
+    """
+    from torch._inductor.ir import get_stride_order, stride_order2fill_order
+
+    stride_order = get_stride_order(query_strides)
+    fill_order = stride_order2fill_order(stride_order)
+    assert out.storage_offset() == 0, "Only support storage_offset == 0"
+    out_strides = _construct_strides(out.shape, fill_order)
+    new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides)
+    new_out.copy_(out)
+    return new_out
+
+
 class TransformGetItemToIndex(TorchFunctionMode):
     # This is needed since we want to support calling
     # A[q_idx], where q_idx is a scalar tensor in score_mod.
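A worked example of what the two helpers added above (_construct_strides and _permute_strides) compute. Illustration only; the shapes, the hand-computed fill order, and the resulting strides are arbitrary examples.

import torch

# Shape (B, H, S, D) = (2, 8, 1024, 64), laid out as if it had been created
# as (B, S, H, D) and then permuted to (B, H, S, D).
query = torch.empty(2, 1024, 8, 64).permute(0, 2, 1, 3)
print(query.stride())  # (524288, 64, 512, 1)

# get_stride_order ranks the dims by stride     -> [3, 1, 2, 0]
# stride_order2fill_order inverts that ranking  -> [3, 1, 2, 0] here
# _construct_strides then walks the fill order, multiplying up the sizes:
#   dim 3 -> 1, dim 1 -> 64, dim 2 -> 64 * 8 = 512, dim 0 -> 512 * 1024 = 524288
out = torch.empty(2, 8, 1024, 64)  # attention output, default contiguous layout
new_out = out.new_empty(out.shape).as_strided(out.shape, (524288, 64, 512, 1))
new_out.copy_(out)

# Same data and shape, but now in the query's stride order.
assert new_out.stride() == query.stride()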
@@ -244,7 +291,7 @@ def sdpa_dense(
         score_mod_other_buffers,
         mask_mod_other_buffers,
     )
-    out = out.contiguous()
+    out = _permute_strides(out, query.stride())
     return out, lse
@@ -432,7 +479,9 @@ def flex_attention_fake_tensor_mode(
            batch_size, num_heads, seq_len_q, dtype=torch.float32
        )
        out_shape = (batch_size, num_heads, seq_len_q, v_head_dim)
-       return query.new_empty(out_shape), logsumexp
+       out = query.new_empty(out_shape)
+       out = _permute_strides(out, query.stride())
+       return out, logsumexp
 
 
 # ---------------------------- Autograd Implementation ----------------------------


@@ -3,7 +3,7 @@
 import logging
 import math
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 import sympy
@@ -17,9 +17,11 @@ from ..ir import (
     ExternKernel,
     FixedLayout,
     FlexibleLayout,
+    get_stride_order,
     InputBuffer,
     IRNode,
     StorageBox,
+    stride_order2fill_order,
     Subgraph,
     TensorBox,
 )
@@ -29,6 +31,29 @@ from ..select_algorithm import autotune_select_algorithm, realize_inputs, Triton
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
+Expr = sympy.Expr
 
 
+def construct_strides(
+    sizes: Sequence[int],
+    fill_order: Sequence[int],
+) -> Sequence[int]:
+    """From a list of sizes and a fill order, construct the strides of the permuted tensor."""
+    # Initialize strides
+    assert len(sizes) == len(
+        fill_order
+    ), "Length of sizes must match the length of the fill order"
+    strides = [0] * len(sizes)
+
+    # Start with stride 1 for the innermost dimension
+    current_stride = 1
+
+    # Iterate through the fill order populating strides
+    for dim in fill_order:
+        strides[dim] = current_stride
+        current_stride *= sizes[dim]
+
+    return strides
+
+
 def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta):
@@ -761,11 +786,18 @@ def flex_attention(
     # This works because only the last dim differs and we check it is contiguous.
     q_strides = query.get_stride()
     assert q_strides[-1] == 1, "Query must be contiguous in the last dimension"
+
+    # Construct output layout with strides matching the query.
+    out_size = [B, Hq, seq_len_q, v_head_dim]
+    stride_order = get_stride_order(query.get_stride())
+    fill_order = stride_order2fill_order(stride_order)
+    out_strides = construct_strides(out_size, fill_order)
+
     layout = FixedLayout(
         query.get_device(),
         query.get_dtype(),
         [B, Hq, seq_len_q, v_head_dim],
-        query.get_stride(),
+        stride=out_strides,
     )
     # see NOTE:[TritonTemplates with multiple outputs]
     logsumexp_shape = [B, Hq, seq_len_q]
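To see why reusing query.get_stride() for the output layout breaks once v_head_dim may differ from the query/key head dim, here is a small numeric illustration (sizes are arbitrary):

import torch

# Query: (B, Hq, S, qk_head_dim) = (2, 8, 1024, 64), contiguous.
query = torch.empty(2, 8, 1024, 64)
print(query.stride())  # (524288, 65536, 64, 1)

# Output shape with a larger value head dim: (2, 8, 1024, 128).
# Reusing the query's strides for that shape would make rows overlap:
# dim 2 would advance by only 64 elements while each row holds 128.
# Building strides from the query's fill order instead gives a valid layout
# in the same stride order as the query:
sizes = [2, 8, 1024, 128]
fill_order = [3, 2, 1, 0]  # a contiguous query fills the last dim first
strides, current = [0] * 4, 1
for dim in fill_order:
    strides[dim] = current
    current *= sizes[dim]
print(strides)  # [1048576, 131072, 128, 1]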