Add Loads from fixed inputs (#162031)
## TODO Check on multi indices

```Python
@cute.jit
def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers):
    in_ptr4 = buffers[0]
    tmp0 = tSrS_ssa
    tmp1 = b_idx
    tmp2 = h_idx
    tmp3 = cute.make_fragment(1, cutlass.Int32)
    tmp4 = tmp3.store(32*tmp1 + tmp2)
    tmp5 = cute.make_fragment(1, cutlass.BFloat16)
    tmp6 = tmp3[0]
    tmp7 = tmp5[0] = (in_ptr4[tmp6])
    tmp8 = (tmp5.load()).to(cutlass.Float32)
    tmp9 = (tmp0 + tmp8)
    tSrS_ssa = tmp9
    return tSrS_ssa
```

I don't think that

```
tmp4 = tmp3.store(32*tmp1 + tmp2)
tmp5 = cute.make_fragment(1, cutlass.BFloat16)
tmp6 = tmp3[0]
tmp7 = tmp5[0] = (in_ptr4[tmp6])
```

is right, since this tmp6 value will be larger than the actual index dim; in this case it's B. See if it's possible to 1D index.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162031
Approved by: https://github.com/v0i0
ghstack dependencies: #161118
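To make the TODO concrete, here is a minimal sketch in plain PyTorch (not CuteDSL, and not part of the PR) of the indexing concern. It assumes a contiguous `(B, H)` buffer with `H = 32`, which is what the generated `32*tmp1 + tmp2` appears to correspond to: the computed value is only valid as an index into the flattened 1D view of the buffer, not into the first dimension of the 2D tensor, where it can exceed `B`.

```Python
import torch

# Hypothetical shapes; H = 32 mirrors the "32*tmp1 + tmp2" in the generated code above.
B, H = 2, 32
bias = torch.randn(B, H)

b, h = 1, 3
flat_idx = H * b + h  # row-major flattened index, analogous to 32*tmp1 + tmp2

# bias[flat_idx] would raise an IndexError whenever flat_idx >= B (the TODO's concern),
# but the same value is a correct index into the flattened 1D view of the buffer:
assert bias.reshape(-1)[flat_idx] == bias[b, h]
```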
Committed by PyTorch MergeBot. Parent: 0a2cde2f06. Commit: 0747d95994.
Tests (TestFlexFlash):

```diff
@@ -27,6 +27,74 @@ def _rel_bias(score, _b, _h, token_q, token_kv):
     return score + (token_q - token_kv)
 
 
+def create_alibi_learned(num_heads=4, dtype=torch.float16):
+    """ALiBi with learned per-head slopes (tests tensor loading)."""
+    slopes = torch.exp2(-torch.linspace(1, 8, num_heads, device="cuda", dtype=dtype))
+
+    def alibi_score_mod(score, b, h, q_idx, kv_idx):
+        bias = (kv_idx - q_idx) * slopes[h]
+        return score + bias
+
+    return alibi_score_mod
+
+
+def create_pos_bias_table(seq_len=512, dtype=torch.float16):
+    """Relative position bias table (tests computed indexing)."""
+    max_len = seq_len
+    table = torch.randn(2 * max_len - 1, device="cuda", dtype=dtype) * 0.1
+
+    def pos_bias_mod(score, b, h, q_idx, kv_idx):
+        rel_pos = kv_idx - q_idx + max_len - 1
+        bias = table[rel_pos]
+        return score + bias
+
+    return pos_bias_mod
+
+
+def create_head_scale(num_heads=4, dtype=torch.float16):
+    """Per-head scaling factors (tests multiplication with tensor loading)."""
+    scales = torch.rand(num_heads, device="cuda", dtype=dtype) + 0.5
+
+    def head_scale_mod(score, b, h, q_idx, kv_idx):
+        return score * scales[h]
+
+    return head_scale_mod
+
+
+def create_batch_bias(batch_size=2, dtype=torch.float16):
+    """Per-batch bias (tests batch indexing)."""
+    bias = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1
+
+    def batch_bias_mod(score, b, h, q_idx, kv_idx):
+        return score + bias[b]
+
+    return batch_bias_mod
+
+
+def create_batch_head_bias(batch_size=2, num_heads=4, dtype=torch.float16):
+    """Per-batch-head bias matrix (tests 2D indexing with batch + head)."""
+    bias_matrix = torch.randn(batch_size, num_heads, device="cuda", dtype=dtype) * 0.5
+
+    def batch_head_mod(score, b, h, q_idx, kv_idx):
+        bias = bias_matrix[b, h]
+        return score + bias
+
+    return batch_head_mod
+
+
+def create_dual_buffer_bias(num_heads=4, seq_len=512, dtype=torch.float16):
+    """Dual buffer loading (tests loading from 2 separate tensors)."""
+    head_bias = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.2
+    pos_scale = torch.arange(seq_len, device="cuda", dtype=dtype)
+
+    def dual_buffer_mod(score, b, h, q_idx, kv_idx):
+        head_component = head_bias[h]
+        pos_component = pos_scale[q_idx] * 0.01
+        return score + head_component + pos_component
+
+    return dual_buffer_mod
+
+
 def create_test_tensors(
     batch_size=2, num_heads=4, seq_len=512, dim=64, dtype=torch.float16, device="cuda"
 ):
@@ -142,6 +210,72 @@ class TestFlexFlash(InductorTestCase):
             f"Flash attention kernel unexpectedly found when force_flash=False. Kernels: {prof_result['kernel_names']}",
         )
 
+    @dtypes(torch.float16, torch.bfloat16)
+    def test_flash_attention_with_alibi_learned(self, device, dtype):
+        """Test flash attention with ALiBi learned slopes (tensor loading)."""
+        q, k, v = create_test_tensors(dtype=dtype, device=device)
+        score_mod = create_alibi_learned(num_heads=4, dtype=dtype)
+        flash_vs_triton(q, k, v, score_mod=score_mod)
+
+    @dtypes(torch.float16, torch.bfloat16)
+    def test_flash_attention_with_pos_bias_table(self, device, dtype):
+        """Test flash attention with position bias table (tensor loading)."""
+        q, k, v = create_test_tensors(dtype=dtype, device=device)
+        score_mod = create_pos_bias_table(seq_len=512, dtype=dtype)
+        flash_vs_triton(q, k, v, score_mod=score_mod)
+
+    @dtypes(torch.float16, torch.bfloat16)
+    def test_flash_attention_with_head_scale(self, device, dtype):
+        """Test flash attention with head scaling (tensor loading)."""
+        q, k, v = create_test_tensors(dtype=dtype, device=device)
+        score_mod = create_head_scale(num_heads=4, dtype=dtype)
+        flash_vs_triton(q, k, v, score_mod=score_mod)
+
+    @dtypes(torch.float16, torch.bfloat16)
+    def test_flash_attention_with_batch_bias(self, device, dtype):
+        """Test flash attention with batch bias (tensor loading)."""
+        q, k, v = create_test_tensors(dtype=dtype, device=device)
+        score_mod = create_batch_bias(batch_size=2, dtype=dtype)
+        flash_vs_triton(q, k, v, score_mod=score_mod)
+
+    @dtypes(torch.float16, torch.bfloat16)
+    def test_flash_attention_with_batch_head_bias(self, device, dtype):
+        """Test flash attention with batch-head bias matrix (tensor loading)."""
+        q, k, v = create_test_tensors(dtype=dtype, device=device)
+        score_mod = create_batch_head_bias(batch_size=2, num_heads=4, dtype=dtype)
+        flash_vs_triton(q, k, v, score_mod=score_mod)
+
+    @dtypes(torch.float16, torch.bfloat16)
+    def test_flash_attention_with_dual_buffer_bias(self, device, dtype):
+        """Test flash attention with dual buffer loading (tensor loading)."""
+        q, k, v = create_test_tensors(dtype=dtype, device=device)
+        score_mod = create_dual_buffer_bias(num_heads=4, seq_len=512, dtype=dtype)
+        flash_vs_triton(q, k, v, score_mod=score_mod)
+
+    @dtypes(torch.float16, torch.bfloat16)
+    def test_force_flash_error_with_requires_grad(self, device, dtype):
+        """Test that force_flash=True raises error when tensor requires gradients."""
+        q, k, v = create_test_tensors(dtype=dtype, device=device)
+
+        # Create a score mod with requires_grad tensor
+        bias = torch.randn(4, device=device, dtype=dtype, requires_grad=True)
+
+        def score_mod_with_grad(score, b, h, q_idx, kv_idx):
+            return score + bias[h]
+
+        compiled_fn = torch.compile(flex_attention)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"force_flash=True but flash attention cannot be used.*require gradients",
+        ):
+            compiled_fn(
+                q,
+                k,
+                v,
+                score_mod=score_mod_with_grad,
+                kernel_options={"force_flash": True},
+            )
+
 
 instantiate_device_type_tests(TestFlexFlash, globals(), only_for="cuda")
```
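As a usage illustration (not part of the PR; the `flash_vs_triton` helper in the tests presumably performs a similar compile-and-compare flow), a score_mod built by one of the factories above is passed straight to `flex_attention`:

```Python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Shapes match create_test_tensors' defaults: (batch, heads, seq_len, head_dim).
q, k, v = (torch.randn(2, 4, 512, 64, device="cuda", dtype=torch.float16) for _ in range(3))
score_mod = create_head_scale(num_heads=4, dtype=torch.float16)  # factory from the diff above

# kernel_options={"force_flash": True}, as in the tests, insists on the flash path.
out = torch.compile(flex_attention)(
    q, k, v, score_mod=score_mod, kernel_options={"force_flash": True}
)
```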
CuteDSL template kernel codegen (CuteDSLTemplateKernel / ModificationWrapperCuteDSL):

```diff
@@ -8,6 +8,7 @@ from typing import Any, Callable, Optional
 import sympy
 
 import torch
+from torch._inductor import config
 from torch._inductor.codegen.common import (
     CSE,
     CSEVariable,
@@ -20,6 +21,7 @@ from torch._inductor.ops_handler import StoreMode
 from torch._inductor.utils import OrderedSet
 from torch._inductor.virtualized import V
 
+from ...utils import sympy_index_symbol
 from .cutedsl_op_overrides import CuteDSLOpOverrides
 
 
@@ -111,6 +113,15 @@ class CuteDSLTemplateKernel(Kernel):
 
         self.cse = CSE(name_prefix="tmp")
 
+        # Track all tensor buffers added during modification processing
+        self.collected_tensor_buffers: list[str] = []
+
+    def kexpr(self, expr: sympy.Expr) -> str:
+        """Convert sympy expression to CuteDSL string representation."""
+        # For CuteDSL, we use standard Python string conversion
+        # since CuteDSL uses Python syntax for expressions
+        return str(expr)
+
     def gen_imports(self) -> str:
         """Generate common imports for CuteDSL templates."""
         imports = IndentedBuffer()
@@ -143,6 +154,8 @@ class CuteDSLTemplateKernel(Kernel):
             "def_kernel": self.def_kernel,
             "gen_defines": lambda: self.gen_defines(**kwargs),
             "get_output": self.get_output,
+            "get_tensor_buffers": self.get_tensor_buffers,
+            "unpack_buffers": self.unpack_buffers,
             "modification": self.modification,
         }
 
@@ -258,6 +271,31 @@ class CuteDSLTemplateKernel(Kernel):
             raise ValueError(f"Output buffer '{buf_name}' not found in args")
         return output
 
+    def get_tensor_buffers(self):
+        """Get list of tensor buffer names that were collected during modifications."""
+        return self.collected_tensor_buffers
+
+    def unpack_buffers(self):
+        """Generate buffer unpacking code via render hook."""
+
+        def hook():
+            tensor_buffers = self.get_tensor_buffers()
+            if not tensor_buffers:
+                return ""
+
+            # Generate unpacking assignments: in_ptr4 = buffers[0], etc.
+            unpacking_lines = []
+            for i, buffer_name in enumerate(tensor_buffers):
+                unpacking_lines.append(f"{buffer_name} = buffers[{i}]")
+
+            return "\n ".join(unpacking_lines)
+
+        # Register the hook and return placeholder
+        placeholder = "<UNPACK_BUFFERS>"
+        assert placeholder not in self.render_hooks
+        self.render_hooks[placeholder] = hook
+        return placeholder
+
     def call_kernel(self, name: str, node=None):
         """Call the kernel function. Simplified version of TritonTemplateKernel.call_kernel."""
         wrapper = V.graph.wrapper_code
@@ -323,6 +361,9 @@ class CuteDSLTemplateKernel(Kernel):
                 "Side-effect only modifications not yet supported for CuteDSL"
             )
 
+        # Add Buffers that were added during modification
+        self.collected_tensor_buffers.extend(modification_handler.tensor_buffers)
+
         return self.body.getvalue()
 
 
@@ -350,6 +391,8 @@ class ModificationWrapperCuteDSL(V.WrapperHandler):  # type: ignore[name-defined]
         self.kernel = kernel
         self.fixed_inputs = fixed_inputs
         self.mask = mask
+        # Track tensor buffers that get added during modification processing
+        self.tensor_buffers: list[str] = []
 
     def _get_input_dtype(self, name: str) -> torch.dtype:
         """Get the dtype for an input from the kernel's named_input_nodes."""
```
```diff
@@ -361,9 +404,83 @@ class ModificationWrapperCuteDSL(V.WrapperHandler):  # type: ignore[name-defined]
     def load(self, name: str, index: sympy.Expr):
         """Handle loading from tensor or fixed(template args) input for CuteDSL."""
         if name not in self.fixed_inputs:
-            raise NotImplementedError(
-                "Tensor loading not yet supported for CuteDSL - only fixed input substitution"
-            )
+            index_str = self._process_indexing(index)
+            var = self._add_kernel_input(name)
+            buffer = V.graph.get_buffer(name)
+            var_dtype = buffer.dtype
+
+            # Get the CuteDSL dtype mapping
+            cute_dtype = CuteDSLOpOverrides.TORCH_TO_CUTE_DTYPE.get(
+                var_dtype, "cutlass.Float32"
+            )
+
+            # NB
+            # This assumes single-value loads which is not generally the case but is a workaround
+            # since we don't have gather support yet. We do loads in non-SSA form then convert
+            # back to SSA form for any remaining operations over the loaded values.
+            #
+            # Pattern:
+            #   index_frag = cute.make_fragment(1, cutlass.Int32)
+            #   index_frag.store(index)
+            #   val_frag = cute.make_fragment(1, dtype)
+            #   index = index_frag[0]
+            #   val_frag[0] = tensor[index]
+            #   result = val_frag.load()
+
+            index_frag = self.kernel.cse.generate(
+                self.kernel.body,
+                "cute.make_fragment(1, cutlass.Int32)",
+                dtype=torch.int32,
+                bounds=ValueRanges.unknown(),
+            )
+
+            self.kernel.cse.generate(
+                self.kernel.body,
+                f"{index_frag}.store({index_str})",
+                dtype=torch.int32,
+                bounds=ValueRanges.unknown(),
+            )
+
+            val_frag = self.kernel.cse.generate(
+                self.kernel.body,
+                f"cute.make_fragment(1, {cute_dtype})",
+                dtype=var_dtype,
+                bounds=ValueRanges.unknown(),
+            )
+
+            index_var = self.kernel.cse.generate(
+                self.kernel.body,
+                f"{index_frag}[0]",
+                dtype=torch.int32,
+                bounds=ValueRanges.unknown(),
+            )
+
+            self.kernel.cse.generate(
+                self.kernel.body,
+                f"{val_frag}[0] = ({var}[{index_var}])",
+                dtype=var_dtype,
+                bounds=ValueRanges.unknown(),
+            )
+
+            final_expr = f"{val_frag}.load()"
+
+            # Handle upcast to fp32 if needed
+            if (
+                var_dtype in (torch.float16, torch.bfloat16)
+                and config.triton.codegen_upcast_to_fp32
+            ):
+                # Apply dtype conversion after fragment load
+                final_expr = f"({final_expr}).to(cutlass.Float32)"
+                var_dtype = torch.float32
+
+            out = self.kernel.cse.generate(
+                self.kernel.body,
+                final_expr,
+                dtype=var_dtype,
+                bounds=ValueRanges.unknown(),
+            )
+            return out
 
         value = self.fixed_inputs[name]
         dtype = self._get_input_dtype(name)
 
@@ -374,7 +491,7 @@ class ModificationWrapperCuteDSL(V.WrapperHandler):  # type: ignore[name-defined]
 
     def indirect_indexing(self, index_var: str, size, check, wrap_neg=True):
         """Convert index variable to symbolic form."""
-        raise NotImplementedError("Indirect indexing not supported")
+        return sympy_index_symbol(str(index_var))
 
     def store(
         self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
@@ -385,12 +502,17 @@ class ModificationWrapperCuteDSL(V.WrapperHandler):  # type: ignore[name-defined]
 
     def _add_kernel_input(self, name: str):
         """Add name as input to kernel and return input ref."""
-        return self.kernel.args.input(name)
+        # Get the remapped name that will be used in the kernel
+        remapped_name = self.kernel.args.input(name)
+        # Track the remapped name for later collection
+        if remapped_name not in self.tensor_buffers:
+            self.tensor_buffers.append(remapped_name)
+        return remapped_name
 
     def _process_indexing(self, index):
         """Process and rename indexing, adding symbols as kernel inputs."""
-        # Convert sympy expression to string representation for CuteDSL
-        return str(index)  # Simplified for now
+        renamed = self.kernel.rename_indexing(index)
+        return self.kernel.kexpr(renamed)
 
     def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
         try:
```
Flex Flash lowering (is_trivial_graph / create_flex_flash_attention_kernel):

```diff
@@ -55,17 +55,18 @@ def input_buffers_require_grads(graph_module):
 
 
 def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
-    """Check if the flex graphs are trivial"""
+    """Check if the flex graphs are compatible with Flash Attention."""
     graph = graph_module.graph
     nodes = list(graph.nodes)
     # Check if it's just placeholder -> output
     placeholders = [n for n in nodes if n.op == "placeholder"]
     output = [n for n in nodes if n.op == "output"]
     assert len(output) == 1, "Got graph w/ multiple outputs"
     output_val = output[0].args[0]
 
     if is_score_graph:
         # Make sure we dont have any captures
-        return len(placeholders) == 5
+        if input_buffers_require_grads(graph_module):
+            return False
+        return True  # party on garth
     # mask mod graph is empty if we have 4 inputs and full_default output
     return len(placeholders) == 4 and output_val.target == torch.ops.aten.full.default
 
@@ -146,8 +147,6 @@ def create_flex_flash_attention_kernel(
     v_head_dim = value.get_size()[-1]
     device = query.get_device()
     dtype = query.get_dtype()
 
     # Ensure device is not None
     assert device is not None, "Device must be specified"
 
     # Match stride pattern from query tensor
@@ -179,7 +178,6 @@ def create_flex_flash_attention_kernel(
 
     choices: list[Any] = []
     causal = kernel_options.get("causal", False)
 
     assert flash_attention_cutedsl_template is not None
     error = flash_attention_cutedsl_template.maybe_append_choice(
         choices,
```
Flash attention CuteDSL template:

```diff
@@ -8,7 +8,8 @@
 v_transposed = V.transpose(1, 2)
 
 @cute.jit
-def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx):
+def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers):
+    {{unpack_buffers()}}
     {{ modification(
         subgraph_number=0,
         output_name="tSrS_ssa",
@@ -25,6 +26,15 @@
 output = {{get_output()}}
 output_transposed = output.transpose(1, 2)
 
+# Collect any additional tensor buffers that were added during modifications
+{% set tensor_buffers = get_tensor_buffers() -%}
+{% if tensor_buffers -%}
+buffers = [{% for buffer in tensor_buffers %}{{buffer}}{% if not loop.last %}, {% endif %}{% endfor %}]
+buffers = list(buffers)
+{% else -%}
+buffers = None
+{% endif -%}
+
 # Out and LSE filled inplace
 _flash_attn_fwd(
     q_transposed,
@@ -35,5 +45,6 @@
     return_lse=True,
     score_mod=score_mod,
     out=output_transposed,
-    lse=LOGSUMEXP
+    lse=LOGSUMEXP,
+    buffers=buffers
 )
```