[Flex Attention] Don't compute fill order to compute stride order just to get fill order back (#138376)

Was a bit confusing to read when working on #138354

"computer-assisted proof"
```
import random

def argsort(seq):
    """
    Return the indices of ``seq`` ordered by ascending value.

    The indices are sorted by descending value with a stable sort (indices
    of equal values keep their relative order) and the result is then
    reversed.  Equal values therefore come out in descending index order,
    e.g. ``argsort([1, 1]) == [1, 0]``.
    """
    descending = sorted(range(len(seq)), key=lambda idx: seq[idx], reverse=True)
    descending.reverse()
    return descending

def stride_order2fill_order(order):
    """
    Invert a stride order into the matching fill order.

    ``order[dim]`` is the rank assigned to dimension ``dim``; entry ``r`` of
    the returned list is the dimension whose rank is ``r``.

    Example (channels-last): stride order ``[3, 0, 2, 1]`` yields fill order
    ``[1, 3, 2, 0]``.

    Raises ``KeyError`` if some rank in ``range(len(order))`` does not occur
    in ``order``.
    """
    rank_to_dim = {}
    for dim, rank in enumerate(order):
        rank_to_dim[rank] = dim
    return [rank_to_dim[rank] for rank in range(len(order))]

def get_stride_order(seq):
    """
    Convert strides to stride order.

    ``out[dim]`` is the position of dimension ``dim`` within ``argsort(seq)``.

    As the point of this script, the function also asserts that converting
    the stride order back with ``stride_order2fill_order`` reproduces the
    ``argsort`` result, i.e. that the stride-order round trip is redundant
    when only the fill order is needed.

    Raises ``AssertionError`` (carrying both lists) if the round trip differs.
    """
    fill_order = argsort(seq)
    out = [0] * len(seq)
    for rank, dim in enumerate(fill_order):
        out[dim] = rank
    # The property under test: inverting the stride order gives the argsort back.
    round_trip = stride_order2fill_order(out)
    assert round_trip == fill_order, (round_trip, fill_order)
    return out

# Check the round-trip property on 1000 random permutations of four distinct strides.
for _trial in range(1000):
    strides = list(range(4))
    random.shuffle(strides)
    get_stride_order(strides)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138376
Approved by: https://github.com/drisspg
This commit is contained in:
eqy
2024-10-22 18:40:39 +00:00
committed by PyTorch MergeBot
parent 2dab4ccb65
commit c0e8458aab
3 changed files with 13 additions and 8 deletions

View File

@ -57,10 +57,9 @@ def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch
torch.Tensor: A new tensor with same shape and data as the input,
but with strides permuted based on the query tensor's stride order.
"""
from torch._inductor.ir import get_stride_order, stride_order2fill_order
from torch._inductor.ir import get_fill_order
stride_order = get_stride_order(query_strides)
fill_order = stride_order2fill_order(stride_order)
fill_order = get_fill_order(query_strides)
assert out.storage_offset() == 0, "Only support storage_offset == 0"
out_strides = _construct_strides(out.shape, fill_order)
new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides)

View File

@ -234,6 +234,14 @@ NHWC_STRIDE_ORDER = [3, 0, 2, 1]
NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1]
def get_fill_order(seq: Sequence[Union[int, torch.SymInt, Expr]]) -> Sequence[int]:
    """
    Convert strides to fill order (argsort)

    Returns the result of ``argsort(seq)`` unchanged.
    """
    return argsort(seq)
def stride_order2fill_order(order: Sequence[Union[int, Integer]]) -> Sequence[int]:
"""
Convert stride order to fill order
@ -250,7 +258,7 @@ def get_stride_order(seq: Sequence[Union[int, torch.SymInt, Expr]]) -> Sequence[
"""
Convert strides to stride order
"""
sorted_idx: List[int] = argsort(seq)
sorted_idx: Sequence[int] = get_fill_order(seq)
out = [0 for _ in range(len(seq))]
for i, elem in enumerate(sorted_idx):
out[elem] = i

View File

@ -17,11 +17,10 @@ from ..ir import (
ExternKernel,
FixedLayout,
FlexibleLayout,
get_stride_order,
get_fill_order,
InputBuffer,
IRNode,
StorageBox,
stride_order2fill_order,
Subgraph,
TensorBox,
)
@ -793,8 +792,7 @@ def flex_attention(
# Construct output layout with strides matching the query.
out_size = [B, Hq, seq_len_q, v_head_dim]
stride_order = get_stride_order(query.get_stride())
fill_order = stride_order2fill_order(stride_order)
fill_order = get_fill_order(query.get_stride())
out_strides = construct_strides(out_size, fill_order)
layout = FixedLayout(