Optimized templated attention to use exp2 (#124356)

Performance relative to FA2 improves from 0.705 to 0.860 after this change.

<img width="1270" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/d58f57ba-e50e-44ea-8a8a-4f13b8650adf">

to

<img width="1277" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/f1945b67-0cfc-463c-a2f6-5812b90677fe">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124356
Approved by: https://github.com/drisspg
Author: chilli
Authored: 2024-04-18 11:40:41 -07:00
Committed by: PyTorch MergeBot
Parent: 0bde4efa84
Commit: e620c3e814
5 changed files with 190 additions and 162 deletions

View File

@@ -6,9 +6,9 @@ from typing import Callable, List
import numpy as np
import torch
import torch.utils.benchmark as benchmark
import torch.nn.functional as F
from tabulate import tabulate
from torch.nn.attention._templated_attention import _compose, _templated_attention
from torch.nn.attention._templated_attention import _templated_attention
from tqdm import tqdm
torch._dynamo.config.automatic_dynamic_shapes = False
@@ -16,15 +16,14 @@ torch._dynamo.config.automatic_dynamic_shapes = False
torch._dynamo.config.cache_size_limit = 1000
from triton.testing import do_bench
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
# warmup
for _ in range(5):
func(*args, **kwargs)
t0 = benchmark.Timer(
stmt="func(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "func": func},
)
return t0.adaptive_autorange(min_run_time=0.1).median * 1e6
return do_bench(lambda: func(*args, **kwargs)) * 1e3
@dataclass(frozen=True)
@@ -110,8 +109,11 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
config.dtype,
device,
)
eager_sdpa = _templated_attention
compiled_sdpa = torch.compile(eager_sdpa)
def eager_sdpa(query, key, value, _):
return F.scaled_dot_product_attention(query, key, value)
compiled_sdpa = torch.compile(_templated_attention)
score_mod = config.score_mod
@@ -190,6 +192,9 @@ def print_results(results: List[Experiment]):
def generate_score_mods() -> List[Callable]:
def noop(score, b, h, m, n):
return score
def causal_mask(score, b, h, token_q, token_kv):
return torch.where(token_q >= token_kv, score, float("-inf"))
@@ -199,14 +204,7 @@
def head_bias(score, b, h, m, n):
return score + 2 * h
def pathological(score, b, h, m, n):
def sin(score, b, h, m, n):
return torch.sin(score)
composed_mod = _compose(*(sin for _ in range(10)))
return composed_mod(score, b, h, m, n)
return [causal_mask, relative_bias, head_bias, pathological]
return [noop, causal_mask, relative_bias, head_bias]
def generate_experiment_configs() -> List[ExperimentConfig]:
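
The benchmark hunk above swaps torch.utils.benchmark.Timer for triton.testing.do_bench and converts its millisecond result to microseconds. A standalone sketch of that helper, assuming Triton and a CUDA device are available (the warmup loop and the SDPA usage example are illustrative):

# Sketch of the do_bench-based timing helper used by the benchmark above.
# do_bench returns a runtime in milliseconds; * 1e3 converts it to microseconds.
from typing import Callable

import torch
from triton.testing import do_bench

def benchmark_in_microseconds(func: Callable, *args, **kwargs) -> float:
    for _ in range(5):  # warmup so compilation/autotuning is not measured
        func(*args, **kwargs)
    return do_bench(lambda: func(*args, **kwargs)) * 1e3

# Example: time a plain SDPA call on CUDA.
if torch.cuda.is_available():
    q = k = v = torch.randn(4, 8, 2048, 64, device="cuda", dtype=torch.float16)
    us = benchmark_in_microseconds(
        torch.nn.functional.scaled_dot_product_attention, q, k, v
    )
    print(f"sdpa: {us:.1f}us")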

View File

@@ -20,6 +20,7 @@ supported_platform = skipUnless(
)
Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
torch.set_float32_matmul_precision("high")
def create_attention(score_mod):
@@ -49,18 +50,23 @@ class TestTemplatedSDPA(InductorTestCase):
q = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
k = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
v = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
ref_out = sdpa_partial(
golden_out = sdpa_partial(
q.to(torch.float64), k.to(torch.float64), v.to(torch.float64)
)
ref_out = sdpa_partial(q, k, v)
compiled_out = compiled_sdpa(q, k, v)
tolerance = Tolerances(atol=2e-2, rtol=2e-2)
torch.testing.assert_close(
ref_out.to(dtype=torch.float32),
compiled_out.to(dtype=torch.float32),
atol=tolerance.atol,
rtol=tolerance.rtol,
)
compiled_error = (golden_out - compiled_out).abs().mean()
ref_error = (golden_out - ref_out).abs().mean()
# Note, it seems like we really are less accurate than the float32
# computation, likely due to the online softmax
if dtype == torch.float32:
fudge_factor = 4.0
else:
fudge_factor = 1.1
if compiled_error > ref_error * fudge_factor:
msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than 10%."
self.assertTrue(False, msg)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
@@ -102,6 +108,14 @@ class TestTemplatedSDPA(InductorTestCase):
self.run_test(score_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_skip_odd_keys(self, dtype: torch.dtype):
def score_mod(score, b, h, q, kv):
return torch.where(kv % 2 == 0, score, float("-inf"))
self.run_test(score_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_alibi_causal(self, dtype: torch.dtype):
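
The test hunk above replaces the fixed atol/rtol check with a relative-accuracy check: eager and compiled outputs are both compared against a float64 golden reference, and the compiled error may only exceed the eager error by a fudge factor. A minimal sketch of that pattern, with placeholder names:

# Sketch of the golden-reference accuracy check used in the test above.
# `attn_fn` and `compiled_fn` are placeholders for sdpa_partial and its
# torch.compile'd counterpart.
import torch

def check_against_golden(attn_fn, compiled_fn, q, k, v, fudge_factor=1.1):
    golden = attn_fn(q.to(torch.float64), k.to(torch.float64), v.to(torch.float64))
    ref = attn_fn(q, k, v)        # eager, low precision
    out = compiled_fn(q, k, v)    # compiled, low precision
    compiled_error = (golden - out.to(torch.float64)).abs().mean()
    ref_error = (golden - ref.to(torch.float64)).abs().mean()
    # The compiled kernel may lose some accuracy to the online softmax, but
    # only up to `fudge_factor` times what eager loses against the same golden.
    assert compiled_error <= ref_error * fudge_factor, (
        f"compiled error {compiled_error} vs eager error {ref_error}"
    )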

View File

@@ -1,8 +1,10 @@
""" Triton Implementation of the Templated SDPA Kernel"""
import logging
from typing import Any, List
import torch
from ..select_algorithm import TritonTemplate
from ..lowering import lowerings, register_lowering
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
log = logging.getLogger(__name__)
aten = torch.ops.aten
@@ -28,6 +30,13 @@ sdpa_template = TritonTemplate(
# Q: Query, K: Key, V: Value
# M: Number of queries, N: Number of keys/values, D: Model dimension
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
# (Modifiable) Config options:
# BLOCK_M
# BLOCK_N
# SCORE_MOD_IS_LINEAR: Is the score modifier linear? If so, we can lift the
# change of base out of the loop
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
# is not masked out? If so, we can skip an extra safety check
# Define Q Strides
stride_qz = {{stride("Q", 0)}}
@@ -49,10 +58,8 @@ sdpa_template = TritonTemplate(
H = {{size("Q", 1)}}
N_CTX = {{size("Q", 2)}}
# TODO I think we should do some performance work
# to find the optimal calls for perf/accuracy to tl.dot
qk_scale = 1.0
MATMUL_PRECISION = tl.float16
MATMUL_PRECISION = Q.dtype.element_ty
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
@@ -89,12 +96,10 @@ sdpa_template = TritonTemplate(
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
# TODO fix me
# qk_scale = sm_scale * 1.44269504
q = tl.load(Q_block_ptr)
if SCORE_MOD_IS_LINEAR:
qk_scale *= 1.44269504
q = (q * qk_scale).to(MATMUL_PRECISION)
# loop over k, v and update accumulator
lo = 0
@@ -106,9 +111,8 @@ sdpa_template = TritonTemplate(
v = tl.load(V_block_ptr)
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k.to(MATMUL_PRECISION))
qk = tl.dot(q, k.to(MATMUL_PRECISION), acc=qk)
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
{{ modification(
score="qk",
b="off_hz // H",
@ -117,6 +121,9 @@ sdpa_template = TritonTemplate(
n="start_n + offs_n[None, :]",
out="qk"
) | indent_except_first(2) }}
# TODO: In the case that score_mod is linear, this can be LICMed
if not SCORE_MOD_IS_LINEAR:
qk *= 1.44269504
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# -- compute scaling constant ---
@@ -124,18 +131,16 @@ sdpa_template = TritonTemplate(
m_i_new = tl.maximum(m_i, row_max)
masked_out_rows = (m_i_new == float("-inf"))
# TODO FIX ME and use 2^x instead of exp
# alpha = tl.math.exp2(m_i - m_i_new)
# p = tl.math.exp2(qk - m_i_new[:, None])
alpha = tl.math.exp(m_i - m_i_new)
alpha = tl.where(masked_out_rows, 0, alpha)
p = tl.math.exp(qk - m_i_new[:, None])
p = tl.where(masked_out_rows[:, None], 0, p)
alpha = tl.math.exp2(m_i - m_i_new)
p = tl.math.exp2(qk - m_i_new[:, None])
if not ROWS_GUARANTEED_SAFE:
alpha = tl.where(masked_out_rows, 0, alpha)
p = tl.where(masked_out_rows[:, None], 0, p)
# -- scale and update acc --
acc_scale = l_i * 0 + alpha # workaround some compiler bug
acc *= acc_scale[:, None]
acc += tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION))
acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc)
# -- update m_i and l_i --
l_i = l_i * alpha + tl.sum(p, 1)
@@ -159,3 +164,125 @@ sdpa_template = TritonTemplate(
{{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc")}}
""",
)
@register_lowering(torch.ops.higher_order.templated_attention)
def templated_attention(*args, **kwargs):
from torch._prims_common import make_contiguous_strides_for
from ..ir import (
ComputedBuffer,
FixedLayout,
FlexibleLayout,
InputBuffer,
StorageBox,
TensorBox,
)
query, key, value, subgraph = args
def create_placeholder(name: str, dtype: torch.dtype) -> InputBuffer:
return TensorBox.create(
InputBuffer(
name,
FixedLayout(
query.get_device(),
dtype,
[
1,
],
[
1,
],
),
)
)
scalar_inps = ["score", "b", "h", "m", "n"]
env = {}
cnt = 0
placeholder_inps = [
create_placeholder(name, dtype)
for name, dtype in [
("score", query.get_dtype()),
("b", torch.int64),
("h", torch.int64),
("m", torch.int64),
("n", torch.int64),
]
]
for node in subgraph.graph_module.graph.nodes:
# There are two classes of placeholder inputs that we need
# to handle differently. For the first n_scalar_inps inputs
# we expect that these placeholders were generated by the make_fx call
# in the templated Attention HOP. So we need to create a new placeholder
# TensorBox for each of these inputs. For the rest of the inputs we
# expect that these are lifted inputs that fill up the '*other_buffers'
# tuple and already have corresponding TensorBoxes passed in as args.
if node.op == "placeholder":
is_lifted_input = cnt >= len(scalar_inps)
env[node] = args[cnt - 1] if is_lifted_input else placeholder_inps[cnt]
cnt += 1
elif node.op == "call_function":
# For call_function we use the default lowerings and pass in the
# already created TensorBoxes as args
from torch.utils._pytree import tree_map
env[node] = lowerings[node.target](
*tree_map(lambda x: env[x] if x in env else x, node.args)
)
elif node.op == "output":
# For the output node we need to create a ComputedBuffer
# which represents the actual score modification
output_buffer = env[node.args[0]]
assert isinstance(output_buffer.data, StorageBox), (
"The output node for the templated attention subgraph must be a StorageBox, but got: ",
type(output_buffer),
)
# Create the ComputedBuffer directly; it will be inlined into the modification block
subgraph_buffer = ComputedBuffer(
name=None,
layout=FlexibleLayout(
device=output_buffer.data.get_device(),
dtype=output_buffer.data.get_dtype(),
size=output_buffer.data.get_size(),
),
data=output_buffer.data.data, # type: ignore[arg-type]
)
layout = FixedLayout(
output_buffer.get_device(),
query.get_dtype(),
query.get_size(),
make_contiguous_strides_for(query.get_size()),
)
choices: List[Any] = []
configs: List[Any] = []
if query.get_dtype() == torch.float32:
configs.append((64, 64, 4, 3))
configs += [
(128, 64, 4, 3),
(128, 128, 4, 3),
(128, 128, 8, 2),
(64, 128, 4, 3),
]
for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
sdpa_template.maybe_append_choice(
choices=choices,
input_nodes=(query, key, value),
layout=layout,
subgraphs=subgraph_buffer,
num_stages=num_stages,
num_warps=num_warps,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=query.get_size()[-1],
# For now, we always assume the "sound" option
SCORE_MOD_IS_LINEAR=False,
ROWS_GUARANTEED_SAFE=False,
)
return autotune_select_algorithm(
"sdpa", choices, [query, key, value], layout
)
raise ValueError("TemplatedAttention was passed a subgraph with no output node!")
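
For readability, here is a pure-PyTorch sketch of the blockwise online softmax the Triton template above implements, showing where the log2(e) change of base is applied and what ROWS_GUARANTEED_SAFE would let the kernel skip. It assumes 2-D single-head q/k/v and a score_mod that takes only the score (the real hook also receives b, h, m, n); it is a reference, not the kernel:

# Pure-PyTorch reference of the blockwise online softmax in the template above.
# Assumes 2-D q/k/v of shape (seq, head_dim); score_mod takes only the score.
import torch

LOG2_E = 1.44269504

def attention_reference(q, k, v, score_mod, block_n=128, rows_guaranteed_safe=False):
    m = q.shape[0]
    m_i = q.new_full((m,), float("-inf"))   # running row max (base-2 domain)
    l_i = q.new_zeros(m)                    # running softmax denominator
    acc = torch.zeros_like(q)               # running numerator @ V
    for start in range(0, k.shape[0], block_n):
        k_blk = k[start:start + block_n]
        v_blk = v[start:start + block_n]
        qk = q @ k_blk.transpose(-1, -2)
        qk = score_mod(qk)                  # apply the score modification
        qk = qk * LOG2_E                    # change of base; hoistable if score_mod is linear
        m_new = torch.maximum(m_i, qk.amax(-1))
        alpha = torch.exp2(m_i - m_new)
        p = torch.exp2(qk - m_new[:, None])
        if not rows_guaranteed_safe:        # keep fully masked rows at zero
            masked = m_new == float("-inf")
            alpha = alpha.masked_fill(masked, 0.0)
            p = p.masked_fill(masked[:, None], 0.0)
        acc = acc * alpha[:, None] + p @ v_blk
        l_i = l_i * alpha + p.sum(-1)
        m_i = m_new
    return acc / l_i[:, None]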

View File

@@ -5623,123 +5623,6 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
return list(map(TensorBox.create, result))
@register_lowering(torch.ops.higher_order.templated_attention)
def templated_attention(*args, **kwargs):
from torch._prims_common import make_contiguous_strides_for
from .ir import (
ComputedBuffer,
FixedLayout,
FlexibleLayout,
InputBuffer,
StorageBox,
TensorBox,
)
query, key, value, subgraph = args
def create_placeholder(name: str, dtype: torch.dtype) -> InputBuffer:
return TensorBox.create(
InputBuffer(
name,
FixedLayout(
query.get_device(),
dtype,
[
1,
],
[
1,
],
),
)
)
scalar_inps = ["score", "b", "h", "m", "n"]
env = {}
cnt = 0
placeholder_inps = [
create_placeholder(name, dtype)
for name, dtype in [
("score", query.get_dtype()),
("b", torch.int64),
("h", torch.int64),
("m", torch.int64),
("n", torch.int64),
]
]
for node in subgraph.graph_module.graph.nodes:
# There are two classes of placeholder inputs that we need
# to handle differently. For the first n_scalar_inps inputs
# we expect that these placeholders were generated by the make_fx call
# in the templated Attention HOP. So we need to create a new placeholder
# TensorBox for each of these inputs. For the rest of the inputs we
# expect that these are lifted inputs that fill up the '*other_buffers'
# tuple and already have corresponding TensorBoxes passed in as args.
if node.op == "placeholder":
is_lifted_input = cnt >= len(scalar_inps)
env[node] = args[cnt - 1] if is_lifted_input else placeholder_inps[cnt]
cnt += 1
elif node.op == "call_function":
# For call_function we use the default lowerings and pass in the
# already created TensorBoxes as args
from torch.utils._pytree import tree_map
env[node] = lowerings[node.target](
*tree_map(lambda x: env[x] if x in env else x, node.args)
)
elif node.op == "output":
# For the output node we need to create a ComputedBuffer
# which represents the actual score modification
output_buffer = env[node.args[0]]
assert isinstance(output_buffer.data, StorageBox), (
"The output node for the templated attention subgraph must be a StorageBox, but got: ",
type(output_buffer),
)
# Create the ComputedBuffer directly; it will be inlined into the modification block
subgraph_buffer = ComputedBuffer(
name=None,
layout=FlexibleLayout(
device=output_buffer.data.get_device(),
dtype=output_buffer.data.get_dtype(),
size=output_buffer.data.get_size(),
),
data=output_buffer.data.data, # type: ignore[arg-type]
)
from .kernel.templated_attention import sdpa_template
layout = FixedLayout(
output_buffer.get_device(),
query.get_dtype(),
query.get_size(),
make_contiguous_strides_for(query.get_size()),
)
choices: List[Any] = []
from .select_algorithm import autotune_select_algorithm
for BLOCK_M, BLOCK_N, num_warps, num_stages in [
(128, 64, 4, 3),
(128, 128, 4, 3),
(128, 128, 8, 2),
(64, 128, 4, 3),
]:
sdpa_template.maybe_append_choice(
choices=choices,
input_nodes=(query, key, value),
layout=layout,
subgraphs=subgraph_buffer,
num_stages=num_stages,
num_warps=num_warps,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=query.get_size()[-1],
)
return autotune_select_algorithm(
"sdpa", choices, [query, key, value], layout
)
raise ValueError("TemplatedAttention was passed a subgraph with no output node!")
@register_lowering(torch.ops.prims._sink_tokens.default)
def _sink_tokens(tokens):
return None

View File

@@ -206,7 +206,13 @@ class CachingAutotuner(KernelInterface):
compiled_binary, launcher = self._precompile_config(
c, warm_cache_only_with_cc
)
except OutOfResources:
except OutOfResources as e:
if len(self.configs) == 1:
raise RuntimeError(
f"Failed to compile triton config: {c}. "
f"Report a fatal compilation error. "
f"{e}"
)
# Skip the config if we run out of resource
continue
self.launchers.append(launcher)
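
The autotuner hunk above turns a silent skip into a hard error when the only candidate config fails to compile, so a single-choice lowering such as the new attention template cannot end up with zero launchers. A hedged sketch of that control flow, with placeholder names:

# Sketch of the "fail loudly when it is the only config" pattern from the hunk above.
# OutOfResources stands in for Triton's resource-exhaustion exception and
# compile_config for CachingAutotuner._precompile_config.
class OutOfResources(Exception):
    pass

def precompile(configs, compile_config):
    launchers = []
    for cfg in configs:
        try:
            launchers.append(compile_config(cfg))
        except OutOfResources as e:
            if len(configs) == 1:
                # A single candidate has no fallback, so surface the failure
                # instead of silently ending up with zero launchers.
                raise RuntimeError(f"Failed to compile triton config: {cfg}. {e}")
            continue  # otherwise skip this config and try the next one
    return launchers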