Add support for capturing tensors with score_mod (#124444)

```
# Benchmark: compiled templated (flex) attention with a captured per-head ALiBi
# slope vs. Flash SDPA baselines.
import torch
import torch.nn.functional as F
from torch.nn.attention._templated_attention import _templated_attention as templated_attention
from triton.testing import do_bench

index = torch.ops.aten.index
torch.manual_seed(0)

B = 16
H = 16
S = 2048
D = 64

# Per-head slope tensor captured by the score_mod closure.
head_scale = torch.randn(H, device="cuda")

def alibi(score, batch, head, token_q, token_kv):
    return score + index(head_scale, [head]) * (token_q - token_kv)

# Dense additive bias for the masked SDPA baseline.
bias = torch.randn(H, S, S, dtype=torch.float16, device="cuda")

query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

compiled = torch.compile(templated_attention)
out = compiled(query, key, value, score_mod=alibi)
out2 = templated_attention(query, key, value, score_mod=alibi)
print((out - out2).abs().mean())
assert (out - out2).abs().mean() < 1e-3
print("Flash (no mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value)))
print("Flash (mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value, attn_mask=bias)))
print("flexattention: ", do_bench(lambda: compiled(query, key, value, score_mod=alibi)))
```
<img width="324" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/18c175d0-2720-4dfd-8747-85b8a8f609f5">
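As a further illustration of what capturing tensors in `score_mod` enables, here is a minimal sketch that reuses `S`, `query`/`key`/`value`, and `compiled` from the script above; the `rel_bias` table and `rel_pos` function are names invented here, mirroring the `test_load_rel_bias` test added in this PR:

```
# A captured relative-position table; loads from it inside score_mod are traced
# and lowered into the generated attention kernel.
rel_bias = torch.randn(2 * S, device="cuda", dtype=torch.float16)

def rel_pos(score, batch, head, token_q, token_kv):
    # (token_q - token_kv) ranges over [-(S - 1), S - 1], so shift by S to index rel_bias.
    return score + torch.ops.aten.index(rel_bias, [(token_q - token_kv) + S])

out_rel = compiled(query, key, value, score_mod=rel_pos)  # recompiles for the new score_mod
print(out_rel.shape)  # (B, H, S, D)
```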

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124444
Approved by: https://github.com/jansel, https://github.com/drisspg
Author: chilli
Date: 2024-04-22 12:12:42 -07:00
Committed by: PyTorch MergeBot
Parent: 0792ceab4b
Commit: 7c253a7776
7 changed files with 214 additions and 67 deletions

View File

@@ -4,7 +4,7 @@ import functools
from collections import namedtuple
from typing import Callable
from unittest import expectedFailure, skipUnless
from unittest import skip, skipUnless
from unittest.mock import patch
import torch
@@ -28,6 +28,8 @@ supported_platform = skipUnless(
Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
torch.set_float32_matmul_precision("high")
index = torch.ops.aten.index
def create_attention(score_mod):
return functools.partial(_templated_attention, score_mod=score_mod)
@@ -39,6 +41,8 @@ test_dtypes = (
else [torch.float16, torch.float32]
)
test_dtypes_fast = [torch.float16]
# TODO float16 was causing ERRORs for tests on ROCm
# See https://github.com/pytorch/pytorch/issues/123531
if common_utils.TEST_WITH_ROCM:
@@ -53,13 +57,19 @@ def _causal_mod(score, b, h, token_q, token_kv):
return torch.where(token_q >= token_kv, score, float("-inf"))
B = 4
H = 8
S = 2048
D = 64
class TestTemplatedSDPA(InductorTestCase):
def run_test(self, score_mod: Callable, dtype: torch.dtype = torch.float16):
sdpa_partial = create_attention(score_mod)
compiled_sdpa = torch.compile(sdpa_partial)
q = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
k = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
v = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
q = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
k = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
v = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
golden_out = sdpa_partial(
q.to(torch.float64), k.to(torch.float64), v.to(torch.float64)
)
@@ -147,23 +157,116 @@ class TestTemplatedSDPA(InductorTestCase):
self.run_test(composed_score_mod, dtype)
# TODO We are currently not capturing free variables in the closure correctly
@expectedFailure
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_captured_buffers(self, dtype: torch.dtype):
head_offset = torch.rand(8, device="cuda", dtype=dtype)
head_offset = torch.rand(H, device="cuda", dtype=dtype)
def score_mod(score, b, h, m, n):
return score + head_offset[h]
return score + index(head_offset, [h])
self.run_test(score_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_seq_masking(self, dtype):
seq_idx = torch.zeros(S, device="cuda", dtype=torch.bool)
seq_idx[S // 2 :] = 1
def seq_mask_mod(score, b, h, q, kv):
return torch.where(
index(seq_idx, [q]) == index(seq_idx, [kv]), score, float("-inf")
)
self.run_test(seq_mask_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_load_from_bias_seq_only(self, dtype):
bias = torch.randn(S, S, device="cuda", dtype=dtype)
def bias_mod(score, b, h, q, kv):
return score + index(bias, [q, kv])
self.run_test(bias_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_load_from_bias_seq_batch(self, dtype):
bias = torch.randn(B, S, S, device="cuda", dtype=dtype)
def bias_mod(score, b, h, q, kv):
return score + index(bias, [b, q, kv])
self.run_test(bias_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_load_from_bias_head_seq_batch(self, dtype):
bias = torch.randn(B, H, S, S, device="cuda", dtype=dtype)
def bias_mod(score, b, h, q, kv):
return score + index(bias, [b, h, q, kv])
self.run_test(bias_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_load_rel_bias(self, dtype):
rel_bias = torch.randn(2 * S, device="cuda", dtype=dtype)
def bias_mod(score, b, h, q, kv):
return score + index(rel_bias, [(q - kv) + S])
self.run_test(bias_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_dependent_causal_bidirectional(self, dtype):
num_bidirectional = torch.randint(0, S, (B,), device="cuda", dtype=torch.int32)
def bias_mod(score, b, h, q, kv):
causal_attention = q >= kv
cur_num_bidirectional = index(num_bidirectional, (b,))
bidirectional_attention_on_video = (q <= cur_num_bidirectional) & (
kv <= cur_num_bidirectional
)
return torch.where(
bidirectional_attention_on_video | causal_attention,
score,
-float("inf"),
)
self.run_test(bias_mod, dtype)
@supported_platform
@skip("Triton bug ") # https://github.com/pytorch/pytorch/issues/124571
@common_utils.parametrize("dtype", test_dtypes)
def test_njt_causal(self, dtype):
offsets = torch.tensor(
[0, 1024, 1024 + 512, S], device="cuda", dtype=torch.int32
)
seq_idx = torch.zeros(S, device="cuda", dtype=torch.int32)
for idx in range(len(offsets) - 1):
seq_idx[offsets[idx] : offsets[idx + 1]] = idx
def create_njt_wrapper(orig_score_mod, offsets, seq_idx):
def njt_score_mod(qk, b, h, q, kv):
q_nested = q - index(offsets, [index(seq_idx, [q])])
kv_nested = kv - index(offsets, [index(seq_idx, [kv])])
return orig_score_mod(qk, b, h, q_nested, kv_nested)
return njt_score_mod
causal_njt = create_njt_wrapper(_causal_mod, offsets, seq_idx)
self.run_test(causal_njt, dtype)
@supported_platform
def test_backwards_fails(self):
make_tensor = functools.partial(
torch.randn,
(4, 8, 2048, 64),
(B, H, S, D),
dtype=torch.float32,
device="cuda",
requires_grad=True,
@@ -177,9 +280,9 @@ class TestTemplatedSDPA(InductorTestCase):
@supported_platform
def test_mixed_dtypes_fails(self):
query = torch.randn((1, 1, 2048, 64), dtype=torch.float32, device="cuda")
key = torch.randn((1, 1, 2048, 64), dtype=torch.float16, device="cuda")
value = torch.randn((1, 1, 2048, 64), dtype=torch.float16, device="cuda")
query = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda")
key = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
value = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
with self.assertRaisesRegex(
ValueError, "Expected query, key, and value to have the same dtype"
):
@@ -201,6 +304,21 @@ class TestTemplatedSDPA(InductorTestCase):
self.run_test(score_mod)
@supported_platform
@patch.object(torch._inductor.config, "max_autotune", True)
def test_max_autotune_with_captured(self):
head_scale = torch.randn(H, device="cuda")
batch_scale = torch.randn(B, device="cuda")
tok_scale = torch.randn(S, device="cuda")
def bias_mod(score, batch, head, token_q, token_kv):
score = score + index(tok_scale, [token_q])
score = score + index(batch_scale, [batch])
score = score + index(head_scale, [head])
return score
self.run_test(bias_mod)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("score_mod", [_identity_mod, _causal_mod])
@@ -211,7 +329,7 @@ class TestTemplatedSDPA(InductorTestCase):
make_tensor = functools.partial(
torch.randn,
(4, 8, 2048, 64),
(B, H, S, D),
dtype=dtype,
device="cuda",
requires_grad=True,
@@ -253,7 +371,7 @@ class TestTemplatedSDPA(InductorTestCase):
def test_logsumexp_only_return(self):
make_tensor = functools.partial(
torch.randn,
(4, 8, 2048, 64),
(B, H, S, D),
dtype=torch.float32,
device="cuda",
requires_grad=True,
@@ -274,7 +392,7 @@ class TestTemplatedSDPA(InductorTestCase):
def test_logsumexp_is_not_fused(self):
make_tensor = functools.partial(
torch.randn,
(4, 8, 2048, 64),
(B, H, S, D),
dtype=torch.float32,
device="cuda",
requires_grad=True,

View File

@@ -1535,12 +1535,10 @@ class TemplatedAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
) -> "VariableTracker":
from .builder import wrap_fx_proxy
query, key, value, score_mod, *other_buffers = self.normalize_to_args(
args, kwargs
)
query, key, value, score_mod = self.normalize_to_args(args, kwargs)
p_args, p_kwargs = self.create_wrapped_node(tx, query, score_mod)
proxied_args = [query, key, value, *other_buffers]
proxied_args = [query, key, value]
# Store the invocation as a call
# Norm_kwargs contains the score_function and we dont want to proxy this because

View File

@@ -60,7 +60,7 @@ def math_attention(
"""
assert len(other_buffers) == 0, "Other buffers are not yet supported."
scores = query @ key.transpose(-2, -1)
scores = (query @ key.transpose(-2, -1)).to(dtype=torch.float32)
b = torch.arange(0, scores.size(0), device=scores.device)
h = torch.arange(0, scores.size(1), device=scores.device)
@@ -179,9 +179,11 @@ def templated_attention_functionalize(
assert isinstance(other_buffers_unwrapped, tuple)
assert all(isinstance(item, torch.Tensor) for item in other_buffers_unwrapped)
example_vals = [torch.zeros((), dtype=query.dtype)] + [
torch.zeros((), dtype=torch.int) for _ in range(4)
]
example_vals = (
[torch.zeros((), dtype=query.dtype)]
+ [torch.zeros((), dtype=torch.int) for _ in range(4)]
+ list(other_buffers_unwrapped)
)
with ctx.redispatch_to_next() as m:
functional_score_mod = ctx.functionalize(score_mod)
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
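
For intuition about the `example_vals` change above: the score_mod subgraph is traced over scalar placeholders for `(score, b, h, q, kv)` plus the captured buffers themselves, which appear as extra explicit arguments of the lifted subgraph. A rough standalone illustration of that tracing shape (it uses `make_fx` directly and a hypothetical `lifted_alibi`; this is an assumption made for illustration, not the exact call this functionalization kernel performs):

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Hypothetical "lifted" score_mod: the captured buffer shows up as an explicit
# trailing argument, mirroring how closed-over tensors are lifted into the subgraph.
def lifted_alibi(score, b, h, q, kv, head_scale):
    return score + head_scale[h] * (q - kv)

head_scale = torch.randn(16)
example_vals = (
    [torch.zeros((), dtype=torch.float32)]                   # scalar score
    + [torch.zeros((), dtype=torch.int) for _ in range(4)]   # b, h, q, kv indices
    + [head_scale]                                            # captured buffer
)
print(make_fx(lifted_alibi)(*example_vals).graph)
```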

View File

@@ -3412,22 +3412,14 @@ class TritonScheduling(BaseScheduling):
buffer_names.update(node.used_buffer_names())
# Get buffers objects
def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]:
if name in V.graph.name_to_buffer:
return V.graph.name_to_buffer[name]
elif name in V.graph.graph_inputs:
return V.graph.graph_inputs[name]
elif name in V.graph.constants:
data = V.graph.constants[name]
return ir.ConstantBuffer(
name,
ir.FixedLayout(
data.device, data.dtype, *V.graph.static_sizes_strides(data)
),
)
raise RuntimeError(f"Failed to find buffer matching name {name}")
buffers = [_get_buffer(name) for name in buffer_names]
def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]:
buf = V.graph.get_buffer(name)
if buf is None:
raise RuntimeError(f"Failed to find buffer matching name {name}")
return buf
buffers = [V.graph.get_buffer(name) for name in buffer_names]
# In theory we can separately check xnumel and rnumel are <= int_max
# but some indexers do use the full linear index so we need to be

View File

@@ -660,6 +660,14 @@ class GraphLowering(torch.fx.Interpreter):
return self.name_to_buffer[buffer_name]
if buffer_name in self.graph_inputs:
return self.graph_inputs[buffer_name]
if buffer_name in self.constants:
data = V.graph.constants[buffer_name]
return ir.ConstantBuffer(
buffer_name,
ir.FixedLayout(
data.device, data.dtype, *V.graph.static_sizes_strides(data)
),
)
return None
def get_dtype(self, buffer_name: str):

View File

@@ -3,6 +3,7 @@ import logging
from typing import Any, List
import torch
from .. import config
from ..lowering import empty_strided, lowerings, register_lowering
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
@@ -114,12 +115,14 @@ sdpa_template = TritonTemplate(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k.to(MATMUL_PRECISION), acc=qk)
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
m = offs_m[:, None]
n = start_n + offs_n[None, :]
{{ modification(
score="qk",
b="off_hz // H",
h="off_hz % H",
m="offs_m[:, None]",
n="start_n + offs_n[None, :]",
m="m",
n="n",
out="qk"
) | indent_except_first(2) }}
# TODO: In the case that score_mod is linear, this can be LICMed
@@ -170,7 +173,8 @@ sdpa_template = TritonTemplate(
)
@register_lowering(torch.ops.higher_order.templated_attention)
# TODO: We probably also need a layout constraint?
@register_lowering(torch.ops.higher_order.templated_attention, type_promotion_kind=None)
def templated_attention(*args, **kwargs):
from torch._prims_common import make_contiguous_strides_for
from ..ir import (
@@ -182,7 +186,7 @@ def templated_attention(*args, **kwargs):
TensorBox,
)
query, key, value, subgraph = args
query, key, value, subgraph, *other_buffers = args
def create_placeholder(name: str, dtype: torch.dtype) -> InputBuffer:
return TensorBox.create(
@@ -272,17 +276,22 @@ def templated_attention(*args, **kwargs):
configs: List[Any] = []
if query.get_dtype() == torch.float32:
configs.append((64, 64, 4, 3))
configs += [
(128, 64, 4, 3),
(128, 128, 4, 3),
(128, 128, 8, 2),
(64, 128, 4, 3),
]
else:
configs.append((128, 64, 4, 3))
if config.max_autotune:
configs += [
(128, 64, 4, 3),
(128, 128, 4, 3),
(128, 128, 8, 2),
(64, 128, 4, 3),
]
# Note, we don't need to pass in the captured buffers explicitly
# because they're implicitly added by the score_mod function
# We do need to explicitly pass it in for autotuning though.
for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
sdpa_template.maybe_append_choice(
choices=choices,
input_nodes=(query, key, value, logsumexp),
input_nodes=[query, key, value, logsumexp],
layout=layout,
subgraphs=subgraph_buffer,
mutated_inputs=[
@@ -298,9 +307,10 @@ def templated_attention(*args, **kwargs):
ROWS_GUARANTEED_SAFE=False,
OUTPUT_LOGSUMEXP=True,
)
inputs_for_autotuning = [query, key, value, logsumexp] + list(other_buffers)
return (
autotune_select_algorithm(
"sdpa", choices, [query, key, value, logsumexp], layout
"sdpa", choices, inputs_for_autotuning, layout
),
logsumexp,
)

View File

@@ -37,7 +37,14 @@ from .exc import CUDACompileError
from .ir import ChoiceCaller, PrimitiveInfoType
from .runtime.hints import DeviceProperties
from .runtime.runtime_utils import do_bench
from .utils import get_dtype_size, Placeholder, sympy_dot, sympy_product, unique
from .utils import (
get_dtype_size,
Placeholder,
sympy_dot,
sympy_index_symbol,
sympy_product,
unique,
)
from .virtualized import V
log = logging.getLogger(__name__)
@@ -269,20 +276,23 @@ class TritonTemplateKernel(TritonKernel):
potential multiple modifications
"""
def add_input(name):
return self.args.input(name)
class PlaceholderSubstitution(V.WrapperHandler): # type: ignore[name-defined]
self.name = "PlaceholderSubstitution"
def load(self, name: str, index: sympy.Expr):
if name not in fixed_inputs:
raise AssertionError(
f"All loads should be coming from fixed inputs - {name}"
)
# If it's not a fixed input, it's a load from a captured
# tensor
var = add_input(name)
return f"tl.load({var} + {index})"
return f"({fixed_inputs[name]})"
# TODO Doesn't work yet
def indirect_indexing(self, index_var, size, check):
return self._inner.indirect_indexing(index_var, size, False)
# return sympy_symbol(str(index_var))
return sympy_index_symbol(str(index_var))
# if self.modification_cache is None:
with V.set_ops_handler(PlaceholderSubstitution(V.ops)):
@@ -589,16 +599,25 @@ class TritonTemplate(KernelTemplate):
+ "-"
)
mod = PyCodeCache.load(code, extra)
_, call_args, _ = kernel.args.python_argdefs()
expected_args = list(unique(x.get_name() for x in input_nodes))
expected_args.extend([fake_out.get_name()])
assert list(call_args)[: len(expected_args)] == expected_args, (
call_args,
expected_args,
input_call_args = tuple(kernel.args.input_buffers.keys())
output_call_args = tuple(kernel.args.output_buffers.keys())
# We expect the input_buffer order to be [*input_nodes, *captured_buffers]
expected_input_args = tuple(unique(x.get_name() for x in input_nodes))
expected_output_args = (fake_out.get_name(),)
assert input_call_args[: len(expected_input_args)] == expected_input_args, (
input_call_args,
expected_input_args,
)
assert output_call_args == expected_output_args, (
output_call_args,
expected_output_args,
)
full_input_nodes = tuple([V.graph.get_buffer(k) for k in input_call_args])
extra_args = V.graph.sizevars.size_hints(
map(sympy.expand, call_args[len(expected_args) :]),
map(sympy.expand, tuple(kernel.args.sizevars.keys())),
fallback=config.unbacked_symint_fallback,
)
@@ -636,13 +655,13 @@ class TritonTemplate(KernelTemplate):
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes),
output_tensor_meta=TensorMeta.from_irnodes(layout),
)
return TritonTemplateCaller(
kernel_hash_name,
input_nodes,
full_input_nodes,
layout,
make_kernel_render,
extra.strip("-").replace("-", ", "),
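
To make the load-rewriting idea above concrete, here is a standalone toy sketch (not Inductor's actual `PlaceholderSubstitution`; the class and names below are simplified inventions for illustration): loads of the fixed placeholder inputs are inlined as expressions, while a load of any other name is treated as a captured tensor and becomes a `tl.load` from a freshly registered kernel argument.

```
# Toy illustration of the substitution above: fixed placeholders inline to expressions,
# anything else is assumed to be a captured tensor and gets a new pointer argument.
class ToyPlaceholderSubstitution:
    def __init__(self, fixed_inputs):
        self.fixed_inputs = fixed_inputs  # e.g. {"score": "qk", "m": "m", "n": "n"}
        self.extra_args = []              # captured tensors promoted to kernel inputs

    def add_input(self, name):
        arg = f"in_ptr_{len(self.extra_args)}"
        self.extra_args.append((name, arg))
        return arg

    def load(self, name, index):
        if name in self.fixed_inputs:
            return f"({self.fixed_inputs[name]})"
        # Not a fixed input: load from the captured tensor's new kernel argument.
        return f"tl.load({self.add_input(name)} + {index})"


subst = ToyPlaceholderSubstitution({"score": "qk", "m": "m", "n": "n"})
print(subst.load("score", "0"))         # -> (qk)
print(subst.load("rel_bias", "m - n"))  # -> tl.load(in_ptr_0 + m - n)
```

As the diff above shows, the real handler additionally returns a `sympy_index_symbol` from `indirect_indexing`, so data-dependent indices produced inside score_mod survive the substitution and reach the captured-buffer loads.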