[FlexAttention] Fix output layout (#135882)
We previously only supported the case where v_head_dim matches qk_head_dim. When we added support for differing head dims, I accidentally kept the query's strides for the output. This PR fixes that bug and also ensures that we always produce the output in the same stride order as the input query.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135882
Approved by: https://github.com/yanboliang, https://github.com/Chillee
Committed by: PyTorch MergeBot
Parent: ad2f0e9f81
Commit: ae02d663cd
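In practice, the fix guarantees that flex_attention produces its output in the query's stride order, which is what the new test below checks. A minimal sketch of that behavior (assuming a CUDA device and the public torch.nn.attention.flex_attention API; the shapes here are illustrative, not taken from the diff):

# Sketch only: assumes a CUDA build with the public flex_attention API available.
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch._inductor.ir import get_stride_order

q, k, v = (torch.randn(2, 2, 128, 16, device="cuda") for _ in range(3))
# Give the inputs a non-default memory layout, e.g. swap batch and heads.
q, k, v = (t.permute(1, 0, 2, 3) for t in (q, k, v))

out = flex_attention(q, k, v)
# After this PR, the output is produced in the same stride order as the query.
assert get_stride_order(out.stride()) == get_stride_order(q.stride())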
@@ -1666,6 +1666,52 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         out = func(query, key, value, block_mask=block_mask)
         out.sum().backward()

+    @supported_platform
+    @common_utils.parametrize("mode", ["eager", "inductor"])
+    @common_utils.parametrize(
+        "permute_order",
+        [
+            (0, 1, 2, 3),  # Default order
+            (1, 0, 2, 3),  # Reverse order
+            (0, 2, 1, 3),  # Mixed order
+            (2, 0, 1, 3),  # Another mixed order
+        ],
+    )
+    @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)])
+    def test_flex_attention_stride_ordering(self, mode, permute_order, shape):
+        from torch._inductor.ir import get_stride_order
+
+        # Setup
+        make_tensor = functools.partial(
+            torch.randn,
+            shape,
+            device="cuda",
+            dtype=torch.float32,
+            requires_grad=True,
+        )
+
+        # Create and permute tensors
+        query, key, value = make_tensor(), make_tensor(), make_tensor()
+        query = query.permute(permute_order)
+        key = key.permute(permute_order)
+        value = value.permute(permute_order)
+
+        if mode == "inductor":
+            func = torch.compile(flex_attention, backend=mode, fullgraph=True)
+        else:
+            func = flex_attention
+
+        out = func(query, key, value)
+
+        out_stride_order = get_stride_order(out.stride())
+        query_stride_order = get_stride_order(query.stride())
+
+        self.assertEqual(
+            out_stride_order,
+            query_stride_order,
+            f"Stride order mismatch: out {out_stride_order}, query {query_stride_order}",
+        )
+
     @supported_platform
     @common_utils.parametrize("compile", [True, False])
     def test_fully_masked_out_rows_0_check(self, compile: bool):
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 import math
-from typing import Any, Callable, Dict, Tuple, Union
+from typing import Any, Callable, Dict, Sequence, Tuple, Union

 import torch
 import torch.utils._pytree as pytree
@@ -23,6 +23,53 @@ from torch.fx.graph_module import GraphModule
 from torch.overrides import TorchFunctionMode


+# Duplicate of _inductor/kernel/flex_attention.py to avoid circular import
+def _construct_strides(
+    sizes: Sequence[int],
+    fill_order: Sequence[int],
+) -> Sequence[int]:
+    """From a list of sizes and a fill order, construct the strides of the permuted tensor."""
+    # Initialize strides
+    assert len(sizes) == len(
+        fill_order
+    ), "Length of sizes must match the length of the fill order"
+    strides = [0] * len(sizes)
+
+    # Start with stride 1 for the innermost dimension
+    current_stride = 1
+
+    # Iterate through the fill order populating strides
+    for dim in fill_order:
+        strides[dim] = current_stride
+        current_stride *= sizes[dim]
+
+    return strides
+
+
+def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch.Tensor:
+    """
+    Create a new tensor with the same data and shape as the input,
+    but with strides permuted based on the input tensor's stride order.
+
+    Args:
+        out (torch.Tensor): The output tensor of attention.
+        query_strides (List[int]): The stride order of the input query tensor
+
+    Returns:
+        torch.Tensor: A new tensor with same shape and data as the input,
+        but with strides permuted based on the query tensor's stride order.
+    """
+    from torch._inductor.ir import get_stride_order, stride_order2fill_order
+
+    stride_order = get_stride_order(query_strides)
+    fill_order = stride_order2fill_order(stride_order)
+    assert out.storage_offset() == 0, "Only support storage_offset == 0"
+    out_strides = _construct_strides(out.shape, fill_order)
+    new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides)
+    new_out.copy_(out)
+    return new_out
+
+
 class TransformGetItemToIndex(TorchFunctionMode):
     # This is needed since we want to support calling
     # A[q_idx], where q_idx is a scalar tensor in score_mod.
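To make the fill-order arithmetic in _construct_strides concrete, here is a small worked example; the sizes and fill order are hypothetical, chosen only to illustrate the computation, and the import path assumes the helper lands in the higher-order-op module patched above:

# Assumes _construct_strides is the private helper added in the hunk above.
from torch._higher_order_ops.flex_attention import _construct_strides

sizes = [2, 4, 8, 16]
fill_order = [3, 1, 2, 0]  # dim 3 is innermost in memory, dim 0 is outermost
# strides[3] = 1
# strides[1] = 1 * sizes[3] = 16
# strides[2] = 16 * sizes[1] = 64
# strides[0] = 64 * sizes[2] = 512
assert _construct_strides(sizes, fill_order) == [512, 16, 64, 1]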
@@ -244,7 +291,7 @@ def sdpa_dense(
         score_mod_other_buffers,
         mask_mod_other_buffers,
     )
-    out = out.contiguous()
+    out = _permute_strides(out, query.stride())
     return out, lse


@@ -432,7 +479,9 @@ def flex_attention_fake_tensor_mode(
             batch_size, num_heads, seq_len_q, dtype=torch.float32
         )
         out_shape = (batch_size, num_heads, seq_len_q, v_head_dim)
-        return query.new_empty(out_shape), logsumexp
+        out = query.new_empty(out_shape)
+        out = _permute_strides(out, query.stride())
+        return out, logsumexp


 # ---------------------------- Autograd Implementation ----------------------------
@@ -3,7 +3,7 @@

 import logging
 import math
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Sequence, Tuple

 import sympy

@@ -17,9 +17,11 @@ from ..ir import (
     ExternKernel,
     FixedLayout,
     FlexibleLayout,
+    get_stride_order,
     InputBuffer,
     IRNode,
     StorageBox,
+    stride_order2fill_order,
     Subgraph,
     TensorBox,
 )
@@ -29,6 +31,29 @@ from ..select_algorithm import autotune_select_algorithm, realize_inputs, Triton

 log = logging.getLogger(__name__)
 aten = torch.ops.aten
+Expr = sympy.Expr


+def construct_strides(
+    sizes: Sequence[int],
+    fill_order: Sequence[int],
+) -> Sequence[int]:
+    """From a list of sizes and a fill order, construct the strides of the permuted tensor."""
+    # Initialize strides
+    assert len(sizes) == len(
+        fill_order
+    ), "Length of sizes must match the length of the fill order"
+    strides = [0] * len(sizes)
+
+    # Start with stride 1 for the innermost dimension
+    current_stride = 1
+
+    # Iterate through the fill order populating strides
+    for dim in fill_order:
+        strides[dim] = current_stride
+        current_stride *= sizes[dim]
+
+    return strides
+
+
 def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta):
@@ -761,11 +786,18 @@ def flex_attention(
     # This works because only the last dim differs and we check it is contiguous.
     q_strides = query.get_stride()
    assert q_strides[-1] == 1, "Query must be contiguous in the last dimension"
+
+    # Construct output layout with strides matching the query.
+    out_size = [B, Hq, seq_len_q, v_head_dim]
+    stride_order = get_stride_order(query.get_stride())
+    fill_order = stride_order2fill_order(stride_order)
+    out_strides = construct_strides(out_size, fill_order)
+
     layout = FixedLayout(
         query.get_device(),
         query.get_dtype(),
         [B, Hq, seq_len_q, v_head_dim],
-        query.get_stride(),
+        stride=out_strides,
     )
     # see NOTE:[TritonTemplates with multiple outputs]
     logsumexp_shape = [B, Hq, seq_len_q]
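For intuition on how the output layout above is derived, here is a short sketch of the stride-order round trip. The channels-last-style strides are purely illustrative, and the import of construct_strides assumes it lives in the inductor kernel module patched above; get_stride_order and stride_order2fill_order are the ..ir helpers imported earlier:

# Sketch only: illustrative strides, with construct_strides taken from the hunk above.
from torch._inductor.ir import get_stride_order, stride_order2fill_order
from torch._inductor.kernel.flex_attention import construct_strides

# Suppose the query has shape (2, 3, 4, 5) with dim 1 innermost in memory.
query_strides = (60, 1, 15, 3)
stride_order = get_stride_order(query_strides)       # [3, 0, 2, 1]
fill_order = stride_order2fill_order(stride_order)   # [1, 3, 2, 0]

# The output reuses the same fill order with its own sizes, so v_head_dim can
# differ from the query head dim (7 instead of 5 here) while the output still
# matches the query's stride order.
out_size = [2, 3, 4, 7]
out_strides = construct_strides(out_size, fill_order)  # [84, 1, 21, 3]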