# mypy: allow-untyped-defs
"""Common utilities and functions for flex attention kernels"""
import math
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Optional, Union
import sympy
import torch
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map, tree_map_only
from ...ir import (
ComputedBuffer,
ExternKernel,
FixedLayout,
FlexibleLayout,
get_fill_order,
InputBuffer,
IRNode,
MutationLayoutSHOULDREMOVE,
Scatter,
ShapeAsConstantBuffer,
StorageBox,
Subgraph,
TensorBox,
)
from ...lowering import (
_full,
check_and_broadcast_indices,
expand,
index_output_size_and_inner_fn,
to_dtype,
)
from ...select_algorithm import realize_inputs
SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]]
def zeros_and_scatter_lowering(shape: list[int], indices, values):
"""To support backwards on captured buffers we register a specific lowering for our specific custom up"""
# Always accumulate into fp32 then cast
grad = _full(0, values.get_device(), torch.float32, shape)
assert isinstance(grad, TensorBox)
grad.realize()
x_size = grad.get_size()
values = to_dtype(values, grad.get_dtype())
indices_loaders = [i.make_loader() if i is not None else None for i in indices]
indices, tensor_indices = check_and_broadcast_indices(indices, grad.get_device())
# We can use the first one since they are all required to be the same size
tensor_size = list(indices[tensor_indices[0]].get_size())
indexed_size = [x_size[i] for i in range(len(indices))]
expected_vals_size, inner_fn = index_output_size_and_inner_fn(
x_size,
indices,
tensor_indices,
tensor_size,
indices_loaders,
indexed_size,
None,
check=True,
)
values = expand(values, expected_vals_size)
device = grad.get_device()
assert device is not None
scatter = Scatter(
device=device,
dtype=grad.get_dtype(),
inner_fn=values.make_loader(),
ranges=expected_vals_size, # iter_ranges,
output_indexer=inner_fn,
scatter_mode="atomic_add",
)
buffer = ComputedBuffer(
name=grad.data.data.name, # type: ignore[attr-defined]
layout=MutationLayoutSHOULDREMOVE(grad),
data=scatter,
)
return buffer
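
# A rough eager-mode sketch of what the lowering above produces (illustrative
# only; `shape`, `indices`, and `values` are stand-ins for the lowering's args,
# and this snippet is not called anywhere in this module): the captured-buffer
# gradient is accumulated with atomic adds into a freshly zeroed fp32 buffer.
#
#   grad = torch.zeros(shape, dtype=torch.float32, device=values.device)
#   grad.index_put_(list(indices), values.to(torch.float32), accumulate=True)
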
def get_fwd_subgraph_outputs(
subgraph_buffer: SubgraphResults, mask_graph_buffer: SubgraphResults
) -> list[Optional[ComputedBuffer]]:
subgraph_buffer = (
subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer]
)
mask_graph_buffer = (
mask_graph_buffer
if isinstance(mask_graph_buffer, Sequence)
else [mask_graph_buffer]
)
return [*subgraph_buffer, *mask_graph_buffer]
def build_subgraph_module_buffer(
args: list[Union[TensorBox, ShapeAsConstantBuffer]],
graph_module: torch.fx.GraphModule,
) -> SubgraphResults:
"""This function's goal is to take in the required args and produce the subgraph buffer
The subgraph buffer is a ComputedBuffer that will be inlined into the triton template
Args:
args: The args that are passed into the subgraph. Contains both fixed and lifted inputs.
subgraph: The Subgraph ir for which to produce the output node
"""
# This one we gotta keep lazy
from ...subgraph_lowering import PointwiseSubgraphLowering
pw_subgraph = PointwiseSubgraphLowering(
graph_module,
root_graph_lowering=V.graph,
allowed_mutations=OrderedSet([torch.ops.flex_lib.zeros_and_scatter.default]),
additional_lowerings={
torch.ops.flex_lib.zeros_and_scatter.default: zeros_and_scatter_lowering
},
)
with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type]
pw_subgraph.run(*args)
def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]:
if output_buffer is None:
return None
if isinstance(output_buffer, ComputedBuffer):
# These nodes are coming from the output of zeros_and_scatter
return output_buffer
assert isinstance(output_buffer, TensorBox), (
"The output node for flex attention's subgraph must be a TensorBox, but got: ",
type(output_buffer),
)
        assert isinstance(output_buffer.data, StorageBox), (
            "The output node for the flex attention subgraph must be a StorageBox, but got: ",
            type(output_buffer.data),
        )
device = output_buffer.data.get_device()
assert device is not None
subgraph_buffer = ComputedBuffer(
name=None,
layout=FlexibleLayout(
device=device,
dtype=output_buffer.data.get_dtype(),
size=output_buffer.data.get_size(),
),
data=output_buffer.data.data, # type: ignore[arg-type]
)
return subgraph_buffer
return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs)
def build_subgraph_buffer(
args: list[Union[TensorBox, ShapeAsConstantBuffer]], subgraph: Subgraph
) -> SubgraphResults:
return build_subgraph_module_buffer(args, subgraph.graph_module)
def maybe_realize(args: list[Optional[IRNode]]):
"""Accepts a list of optional IRNodes and returns a list of realized IRNodes"""
return tree_map(
lambda x: (
realize_inputs(x)
if x is not None and not isinstance(x, sympy.Symbol)
else x
),
args,
)
def freeze_irnodes(tree: Any) -> Any:
"""Freeze layouts for every IRNode contained in a pytree."""
if tree is None:
return None
def _freeze(node: IRNode) -> IRNode:
try:
node.freeze_layout()
except NotImplementedError:
pass
return node
return tree_map_only(IRNode, _freeze, tree)
def create_placeholder(
name: str,
dtype: torch.dtype,
device: torch.device,
size: Optional[list[int]] = None,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
"""Creates a placeholder input buffers for producing subgraph_output."""
input_buffer = InputBuffer(
name=name,
layout=FixedLayout(
device,
dtype,
size if size else [],
FlexibleLayout.contiguous_strides(size) if size else [],
),
)
return TensorBox.create(input_buffer)
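
# A hedged sketch of typical usage (the placeholder names, dtypes, and the
# `lifted` list below are assumptions about the flex attention caller, not part
# of this module): scalar placeholders are created for the score and the
# (batch, head, q_idx, kv_idx) indices and then fed to build_subgraph_buffer.
#
#   score = create_placeholder("score", query_dtype, device)
#   b, h, m, n = (create_placeholder(s, torch.int32, device) for s in "bhmn")
#   subgraph_buffer = build_subgraph_buffer([score, b, h, m, n, *lifted], subgraph)
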
def construct_strides(
sizes: Sequence[int],
fill_order: Sequence[int],
) -> Sequence[int]:
"""From a list of sizes and a fill order, construct the strides of the permuted tensor."""
# Initialize strides
assert len(sizes) == len(fill_order), (
"Length of sizes must match the length of the fill order"
)
strides = [0] * len(sizes)
# Start with stride 1 for the innermost dimension
current_stride = 1
# Iterate through the fill order populating strides
for dim in fill_order:
strides[dim] = current_stride
current_stride *= sizes[dim]
return strides
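
# Illustrative examples (sizes and fill orders chosen here for exposition):
# the fill order lists dimensions from innermost to outermost.
#
#   construct_strides([2, 3, 4], fill_order=[2, 1, 0])  # -> [12, 4, 1] (contiguous)
#   construct_strides([2, 3, 4], fill_order=[1, 2, 0])  # -> [12, 1, 3] (dim 1 innermost)
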
def infer_dense_strides(size: Sequence[int], orig_strides: Sequence[int]):
"""This is a mirror of the same function in aten/src/ATen/ExpandUtils.cpp
Args:
size: The size of the output tensor
orig_strides: The strides of the input tensor
Returns:
        List[int]: Dense non-overlapping strides that preserve the input tensor's layout permutation.
        The returned strides follow the same stride propagation rules as TensorIterator, which
        matches the behavior of empty_like().
"""
fill_order = get_fill_order(orig_strides, V.graph.sizevars.shape_env)
return construct_strides(size, fill_order)
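
# Illustrative example (the concrete sizes/strides are assumptions): a
# non-dense input whose middle dimension is innermost keeps that permutation
# but gets dense strides back.
#
#   infer_dense_strides([2, 3, 4], [24, 1, 6])  # -> [12, 1, 3]
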
def create_indices_fake(x) -> torch.Tensor:
"""Create a fake indices that is used for autotuning."""
size = [V.graph.sizevars.size_hint(i) for i in x.get_size()]
indices = torch.arange(0, size[-1], dtype=x.get_dtype(), device=x.get_device())
indices = indices.expand(size).contiguous()
return indices
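
# For example (shape chosen purely for illustration), a [2, 4] indices input
# yields arange values along the last dimension broadcast across the rest:
#
#   [[0, 1, 2, 3],
#    [0, 1, 2, 3]]
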
def create_num_blocks_fake_generator(sparse_indices):
"""Create a fake num_blocks that is used for autotuning.
The idea here is that we need to create a real tensor with real data
that's representative for benchmarking.
For example, returning all zeros for the `kv_num_blocks` input would mean
that we are computing 0 blocks for each row, which would provide bogus
autotuning results.
In this case, we choose to use min(16, max_block) blocks, because I
(Horace) think it'll probably result in pretty representative performance.
If it's too short then prefetching won't help. If it's too long then
autotuning will take longer for no good reason.
"""
def create_num_blocks_fake(x) -> torch.Tensor:
num_blocks_for_autotuning = V.graph.sizevars.size_hint(sparse_indices.shape[-1])
size = [V.graph.sizevars.size_hint(i) for i in x.get_size()]
return torch.full(
size,
num_blocks_for_autotuning,
dtype=x.get_dtype(),
device=x.get_device(),
)
return create_num_blocks_fake
def contiguous_last_dim(x):
"""Ensure that realized IR node has a contiguous stride in the last dimension."""
strides = x.maybe_get_stride()
if strides and strides[-1] != 1:
contiguous_stride_order = list(reversed(range(len(x.get_size()))))
return ExternKernel.require_stride_order(x, contiguous_stride_order)
return x
def set_head_dim_values(
kernel_options: dict[str, Any], qk_head_dim, v_head_dim, graph_sizevars
):
"""
Mutates kernel options, adding head dimension calculations.
Args:
kernel_options: Dictionary to populate with options
qk_head_dim: Query/Key head dimension
v_head_dim: Value head dimension
graph_sizevars: Graph size variables object with guard_int method
"""
# QK dimensions
qk_head_dim_static = graph_sizevars.guard_int(qk_head_dim)
kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim_static)
kernel_options.setdefault(
"QK_HEAD_DIM_ROUNDED", next_power_of_two(qk_head_dim_static)
)
# V dimensions
v_head_dim_static = graph_sizevars.guard_int(v_head_dim)
kernel_options.setdefault("V_HEAD_DIM", v_head_dim_static)
kernel_options.setdefault(
"V_HEAD_DIM_ROUNDED", next_power_of_two(v_head_dim_static)
)
# Safety flag
kernel_options.setdefault(
"SAFE_HEAD_DIM",
is_power_of_2(qk_head_dim_static) and is_power_of_2(v_head_dim_static),
)
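
# Example of the options this populates (head dims chosen for illustration):
# with qk_head_dim=96 and v_head_dim=64, kernel_options gains
#
#   {"QK_HEAD_DIM": 96, "QK_HEAD_DIM_ROUNDED": 128,
#    "V_HEAD_DIM": 64,  "V_HEAD_DIM_ROUNDED": 64,
#    "SAFE_HEAD_DIM": False}
#
# SAFE_HEAD_DIM is False because 96 is not a power of two.
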
def is_power_of_2(n):
return n != 0 and ((n & (n - 1)) == 0)
def next_power_of_two(n):
if n <= 0:
return 1
return 2 ** math.ceil(math.log2(n))
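
# Sanity examples for the two helpers above:
#   is_power_of_2(64) -> True,    is_power_of_2(96) -> False
#   next_power_of_two(96) -> 128, next_power_of_two(1) -> 1, next_power_of_two(0) -> 1
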
_TEMPLATE_DIR = Path(__file__).parent / "templates"
def load_template(name: str) -> str:
"""Load a template file and return its content."""
with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
return f.read()
# Template strings have been moved to templates/common.py.jinja