# mypy: allow-untyped-defs
"""Common utilities and functions for flex attention kernels"""

import math
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Optional, Union

import sympy

import torch
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map

from ...ir import (
    ComputedBuffer,
    ExternKernel,
    FixedLayout,
    FlexibleLayout,
    get_fill_order,
    InputBuffer,
    IRNode,
    MutationLayoutSHOULDREMOVE,
    Scatter,
    ShapeAsConstantBuffer,
    StorageBox,
    Subgraph,
    TensorBox,
)
from ...lowering import (
    _full,
    check_and_broadcast_indices,
    expand,
    index_output_size_and_inner_fn,
    to_dtype,
)
from ...select_algorithm import realize_inputs


SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]]


def zeros_and_scatter_lowering(shape: list[int], indices, values):
    """To support backwards on captured buffers we register a specific lowering for our custom zeros_and_scatter op"""
    # Always accumulate into fp32 then cast
    grad = _full(0, values.get_device(), torch.float32, shape)
    assert isinstance(grad, TensorBox)
    grad.realize()
    x_size = grad.get_size()
    values = to_dtype(values, grad.get_dtype())
    indices_loaders = [i.make_loader() if i is not None else None for i in indices]
    indices, tensor_indices = check_and_broadcast_indices(indices, grad.get_device())
    # We can use the first one since they are all required to be the same size
    tensor_size = list(indices[tensor_indices[0]].get_size())
    indexed_size = [x_size[i] for i in range(len(indices))]

    expected_vals_size, inner_fn = index_output_size_and_inner_fn(
        x_size,
        indices,
        tensor_indices,
        tensor_size,
        indices_loaders,
        indexed_size,
        None,
        check=True,
    )

    values = expand(values, expected_vals_size)
    device = grad.get_device()
    assert device is not None
    scatter = Scatter(
        device=device,
        dtype=grad.get_dtype(),
        inner_fn=values.make_loader(),
        ranges=expected_vals_size,  # iter_ranges,
        output_indexer=inner_fn,
        scatter_mode="atomic_add",
    )

    buffer = ComputedBuffer(
        name=grad.data.data.name,  # type: ignore[attr-defined]
        layout=MutationLayoutSHOULDREMOVE(grad),
        data=scatter,
    )
    return buffer
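# Rough eager-mode intuition for zeros_and_scatter_lowering above (a sketch,
# not the actual lowering path):
#     out = torch.zeros(shape, dtype=torch.float32, device=values.device)
#     out.index_put_(indices, values.to(torch.float32), accumulate=True)
# i.e. gradients for captured buffers are accumulated with atomic adds into a
# freshly zeroed fp32 buffer.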
def get_fwd_subgraph_outputs(
    subgraph_buffer: SubgraphResults, mask_graph_buffer: SubgraphResults
) -> list[Optional[ComputedBuffer]]:
    subgraph_buffer = (
        subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer]
    )
    mask_graph_buffer = (
        mask_graph_buffer
        if isinstance(mask_graph_buffer, Sequence)
        else [mask_graph_buffer]
    )
    return [*subgraph_buffer, *mask_graph_buffer]


def build_subgraph_module_buffer(
    args: list[Union[TensorBox, ShapeAsConstantBuffer]],
    graph_module: torch.fx.GraphModule,
) -> SubgraphResults:
    """This function's goal is to take in the required args and produce the subgraph buffer.
    The subgraph buffer is a ComputedBuffer that will be inlined into the Triton template.

    Args:
        args: The args that are passed into the subgraph. Contains both fixed and lifted inputs.
        graph_module: The GraphModule for which to produce the output node
    """
    # This one we gotta keep lazy
    from ...subgraph_lowering import PointwiseSubgraphLowering

    pw_subgraph = PointwiseSubgraphLowering(
        graph_module,
        root_graph_lowering=V.graph,
        allowed_mutations=OrderedSet([torch.ops.flex_lib.zeros_and_scatter.default]),
        additional_lowerings={
            torch.ops.flex_lib.zeros_and_scatter.default: zeros_and_scatter_lowering
        },
    )
    with V.set_graph_handler(pw_subgraph):  # type: ignore[arg-type]
        pw_subgraph.run(*args)

    # Since we are allowing mutations/buffer creation, we need to register any fresh buffers
    # created during the pointwise subgraph lowering
    if len(pw_subgraph.buffers) > 0:
        for buffer in pw_subgraph.buffers:
            V.graph.register_buffer(buffer)

    def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]:
        if output_buffer is None:
            return None
        if isinstance(output_buffer, ComputedBuffer):
            # These nodes are coming from the output of zeros_and_scatter
            return output_buffer
        assert isinstance(output_buffer, TensorBox), (
            "The output node for flex attention's subgraph must be a TensorBox, but got: ",
            type(output_buffer),
        )
        assert isinstance(output_buffer.data, StorageBox), (
            "The output node for the flex attention subgraph must be a StorageBox, but got: ",
            type(output_buffer),
        )
        device = output_buffer.data.get_device()
        assert device is not None
        subgraph_buffer = ComputedBuffer(
            name=None,
            layout=FlexibleLayout(
                device=device,
                dtype=output_buffer.data.get_dtype(),
                size=output_buffer.data.get_size(),
            ),
            data=output_buffer.data.data,  # type: ignore[arg-type]
        )
        return subgraph_buffer

    return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs)


def build_subgraph_buffer(
    args: list[Union[TensorBox, ShapeAsConstantBuffer]], subgraph: Subgraph
) -> SubgraphResults:
    return build_subgraph_module_buffer(args, subgraph.graph_module)
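# Hypothetical call-site sketch for build_subgraph_buffer above (names and
# dtypes are illustrative, not taken from this file): a caller builds scalar
# placeholders for the score_mod inputs with create_placeholder() below and
# inlines the resulting buffer into its template, roughly:
#     placeholder_inps = [
#         create_placeholder(name, dtype, device)
#         for name, dtype in [
#             ("score", torch.float32),
#             ("b", torch.int32),
#             ("h", torch.int32),
#             ("m", torch.int32),
#             ("n", torch.int32),
#         ]
#     ]
#     subgraph_buffer = build_subgraph_buffer(
#         placeholder_inps + captured_inputs, score_mod_subgraph
#     )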
def maybe_realize(args: list[Optional[IRNode]]):
    """Accepts a list of optional IRNodes and returns a list of realized IRNodes"""
    return tree_map(
        lambda x: (
            realize_inputs(x)
            if x is not None and not isinstance(x, sympy.Symbol)
            else x
        ),
        args,
    )


def create_placeholder(
    name: str,
    dtype: torch.dtype,
    device: torch.device,
    size: Optional[list[int]] = None,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """Creates a placeholder input buffer for producing subgraph_output."""
    input_buffer = InputBuffer(
        name=name,
        layout=FixedLayout(
            device,
            dtype,
            size if size else [],
            FlexibleLayout.contiguous_strides(size) if size else [],
        ),
    )
    return TensorBox.create(input_buffer)
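# Note on create_placeholder above: when size is None the placeholder is a
# 0-dim scalar input, which is how (for example) the per-element score fed to a
# score_mod subgraph can be modeled; pass an explicit size for tensor-shaped
# placeholder inputs.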
def construct_strides(
    sizes: Sequence[int],
    fill_order: Sequence[int],
) -> Sequence[int]:
    """From a list of sizes and a fill order, construct the strides of the permuted tensor."""
    # Initialize strides
    assert len(sizes) == len(fill_order), (
        "Length of sizes must match the length of the fill order"
    )
    strides = [0] * len(sizes)

    # Start with stride 1 for the innermost dimension
    current_stride = 1

    # Iterate through the fill order populating strides
    for dim in fill_order:
        strides[dim] = current_stride
        current_stride *= sizes[dim]

    return strides
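# Worked example for construct_strides above:
#     construct_strides([2, 3, 4], fill_order=[2, 1, 0]) -> [12, 4, 1]  (contiguous)
#     construct_strides([2, 3, 4], fill_order=[1, 2, 0]) -> [12, 1, 3]  (dim 1 innermost)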
def infer_dense_strides(size: Sequence[int], orig_strides: Sequence[int]):
    """This is a mirror of the same function in aten/src/ATen/ExpandUtils.cpp

    Args:
        size: The size of the output tensor
        orig_strides: The strides of the input tensor
    Returns:
        List[int]: Dense non-overlapping strides that preserve the input tensor's layout permutation.
        The returned strides follow the same stride propagation rules as TensorIterator. This matches
        the behavior of empty_like()
    """
    fill_order = get_fill_order(orig_strides, V.graph.sizevars.shape_env)
    return construct_strides(size, fill_order)
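# Illustrative example for infer_dense_strides above (assuming get_fill_order
# orders dimensions from innermost to outermost by stride): an input of size
# (2, 3, 4) with strides (12, 1, 3) (dim 1 innermost) maps to dense strides
# (12, 1, 3); a non-dense input of the same size with strides (100, 1, 25)
# keeps the same permutation but is densified to (12, 1, 3).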
def create_indices_fake(x) -> torch.Tensor:
    """Create a fake indices tensor that is used for autotuning."""
    size = [V.graph.sizevars.size_hint(i) for i in x.get_size()]
    indices = torch.arange(0, size[-1], dtype=x.get_dtype(), device=x.get_device())
    indices = indices.expand(size).contiguous()
    return indices
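# Illustrative example for create_indices_fake above: if x is a kv_indices-like
# tensor of size (B, H, Q_BLOCKS, KV_BLOCKS), every row of the fake tensor is
# [0, 1, ..., KV_BLOCKS - 1], i.e. autotuning benchmarks a mask that visits the
# KV blocks in order. (The shape names here are assumptions for illustration.)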
def create_num_blocks_fake_generator(sparse_indices):
    """Create a fake num_blocks that is used for autotuning.

    The idea here is that we need to create a real tensor with real data
    that's representative for benchmarking.
    For example, returning all zeros for the `kv_num_blocks` input would mean
    that we are computing 0 blocks for each row, which would provide bogus
    autotuning results.

    In this case, we choose to use min(16, max_block) blocks, because I
    (Horace) think it'll probably result in pretty representative performance.
    If it's too short then prefetching won't help. If it's too long then
    autotuning will take longer for no good reason.
    """

    def create_num_blocks_fake(x) -> torch.Tensor:
        num_blocks_for_autotuning = V.graph.sizevars.size_hint(sparse_indices.shape[-1])
        size = [V.graph.sizevars.size_hint(i) for i in x.get_size()]
        return torch.full(
            size,
            num_blocks_for_autotuning,
            dtype=x.get_dtype(),
            device=x.get_device(),
        )

    return create_num_blocks_fake


def contiguous_last_dim(x):
    """Ensure that realized IR node has a contiguous stride in the last dimension."""
    strides = x.maybe_get_stride()
    if strides and strides[-1] != 1:
        contiguous_stride_order = list(reversed(range(len(x.get_size()))))
        return ExternKernel.require_stride_order(x, contiguous_stride_order)
    return x


def set_head_dim_values(
    kernel_options: dict[str, Any], qk_head_dim, v_head_dim, graph_sizevars
):
    """
    Mutates kernel options, adding head dimension calculations.

    Args:
        kernel_options: Dictionary to populate with options
        qk_head_dim: Query/Key head dimension
        v_head_dim: Value head dimension
        graph_sizevars: Graph size variables object with guard_int method
    """
    # QK dimensions
    qk_head_dim_static = graph_sizevars.guard_int(qk_head_dim)
    kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim_static)
    kernel_options.setdefault(
        "QK_HEAD_DIM_ROUNDED", next_power_of_two(qk_head_dim_static)
    )

    # V dimensions
    v_head_dim_static = graph_sizevars.guard_int(v_head_dim)
    kernel_options.setdefault("V_HEAD_DIM", v_head_dim_static)
    kernel_options.setdefault(
        "V_HEAD_DIM_ROUNDED", next_power_of_two(v_head_dim_static)
    )

    # Safety flag
    kernel_options.setdefault(
        "SAFE_HEAD_DIM",
        is_power_of_2(qk_head_dim_static) and is_power_of_2(v_head_dim_static),
    )
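# Worked example for set_head_dim_values above: with qk_head_dim=96 and
# v_head_dim=128 the populated options are
#     QK_HEAD_DIM=96,  QK_HEAD_DIM_ROUNDED=128,
#     V_HEAD_DIM=128,  V_HEAD_DIM_ROUNDED=128,
#     SAFE_HEAD_DIM=False   (96 is not a power of two)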
def is_power_of_2(n):
    return n != 0 and ((n & (n - 1)) == 0)


def next_power_of_two(n):
    if n <= 0:
        return 1
    return 2 ** math.ceil(math.log2(n))
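# Quick reference for the helpers above:
#     is_power_of_2(64) -> True,  is_power_of_2(96) -> False
#     next_power_of_two(96) -> 128,  next_power_of_two(1) -> 1,  next_power_of_two(0) -> 1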
_TEMPLATE_DIR = Path(__file__).parent / "templates"


def load_template(name: str) -> str:
    """Load a template file and return its content."""
    with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
        return f.read()


# Template strings have been moved to templates/common.py.jinja