mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Flex Attention] Don't compute fill order to compute stride order just to get fill order back (#138376)
Was a bit confusing to read when working on #138354 "computer-assisted proof" ``` import random def argsort(seq): # preserve original order for equal strides getter = seq.__getitem__ a_r = range(len(seq)) return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413 def stride_order2fill_order(order): """ Convert stride order to fill order For channel last format, stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0] """ lookup = {pos: idx for idx, pos in enumerate(order)} fill_order = [lookup[i] for i in range(len(order))] return fill_order def get_stride_order(seq): """ Convert strides to stride order """ sorted_idx: List[int] = argsort(seq) out = [0 for _ in range(len(seq))] a = sorted_idx.copy() for i, elem in enumerate(sorted_idx): out[elem] = i fillorder = stride_order2fill_order(out) assert fillorder == sorted_idx return out for _ in range(1000): a = [0, 1, 2, 3] random.shuffle(a) get_stride_order(a) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/138376 Approved by: https://github.com/drisspg
This commit is contained in:
@ -57,10 +57,9 @@ def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch
|
||||
torch.Tensor: A new tensor with same shape and data as the input,
|
||||
but with strides permuted based on the query tensor's stride order.
|
||||
"""
|
||||
from torch._inductor.ir import get_stride_order, stride_order2fill_order
|
||||
from torch._inductor.ir import get_fill_order
|
||||
|
||||
stride_order = get_stride_order(query_strides)
|
||||
fill_order = stride_order2fill_order(stride_order)
|
||||
fill_order = get_fill_order(query_strides)
|
||||
assert out.storage_offset() == 0, "Only support storage_offset == 0"
|
||||
out_strides = _construct_strides(out.shape, fill_order)
|
||||
new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides)
|
||||
|
@ -234,6 +234,14 @@ NHWC_STRIDE_ORDER = [3, 0, 2, 1]
|
||||
NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1]
|
||||
|
||||
|
||||
def get_fill_order(seq: Sequence[Union[int, torch.SymInt, Expr]]) -> Sequence[int]:
|
||||
"""
|
||||
Convert strides to fill order (argsort)
|
||||
"""
|
||||
sorted_idx: Sequence[int] = argsort(seq)
|
||||
return sorted_idx
|
||||
|
||||
|
||||
def stride_order2fill_order(order: Sequence[Union[int, Integer]]) -> Sequence[int]:
|
||||
"""
|
||||
Convert stride order to fill order
|
||||
@ -250,7 +258,7 @@ def get_stride_order(seq: Sequence[Union[int, torch.SymInt, Expr]]) -> Sequence[
|
||||
"""
|
||||
Convert strides to stride order
|
||||
"""
|
||||
sorted_idx: List[int] = argsort(seq)
|
||||
sorted_idx: Sequence[int] = get_fill_order(seq)
|
||||
out = [0 for _ in range(len(seq))]
|
||||
for i, elem in enumerate(sorted_idx):
|
||||
out[elem] = i
|
||||
|
@ -17,11 +17,10 @@ from ..ir import (
|
||||
ExternKernel,
|
||||
FixedLayout,
|
||||
FlexibleLayout,
|
||||
get_stride_order,
|
||||
get_fill_order,
|
||||
InputBuffer,
|
||||
IRNode,
|
||||
StorageBox,
|
||||
stride_order2fill_order,
|
||||
Subgraph,
|
||||
TensorBox,
|
||||
)
|
||||
@ -793,8 +792,7 @@ def flex_attention(
|
||||
|
||||
# Construct output layout with strides matching the query.
|
||||
out_size = [B, Hq, seq_len_q, v_head_dim]
|
||||
stride_order = get_stride_order(query.get_stride())
|
||||
fill_order = stride_order2fill_order(stride_order)
|
||||
fill_order = get_fill_order(query.get_stride())
|
||||
out_strides = construct_strides(out_size, fill_order)
|
||||
|
||||
layout = FixedLayout(
|
||||
|
Reference in New Issue
Block a user