Optimized templated attention to use exp2 (#124356)
Performance relative to FA2 improves from 0.705 to 0.860 with this change.

Before: https://github.com/pytorch/pytorch/assets/6355099/d58f57ba-e50e-44ea-8a8a-4f13b8650adf
After: https://github.com/pytorch/pytorch/assets/6355099/f1945b67-0cfc-463c-a2f6-5812b90677fe

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124356
Approved by: https://github.com/drisspg
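Why exp2: `exp(x)` can be rewritten as `exp2(x * log2(e))`, and base-2 exponentials are the fast path on NVIDIA GPUs, which is why FlashAttention-style kernels conventionally work in base 2. The softmax result is unchanged because the constant factor is folded into the scores before exponentiation. A minimal sketch of the identity in plain PyTorch (illustration only, not part of this diff):

```python
import torch

LOG2_E = 1.4426950408889634  # the kernel hardcodes the truncated 1.44269504

x = torch.randn(1024)
# Change of base: e^x == 2^(x * log2(e))
torch.testing.assert_close(torch.exp(x), torch.exp2(x * LOG2_E))

# Softmax is identical in either base once the factor is applied:
x2 = x * LOG2_E
torch.testing.assert_close(
    torch.softmax(x, dim=0),
    torch.exp2(x2 - x2.max()) / torch.exp2(x2 - x2.max()).sum(),
)
```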
@@ -6,9 +6,9 @@ from typing import Callable, List
 import numpy as np
 import torch
-import torch.utils.benchmark as benchmark
+import torch.nn.functional as F
 from tabulate import tabulate
-from torch.nn.attention._templated_attention import _compose, _templated_attention
+from torch.nn.attention._templated_attention import _templated_attention
 from tqdm import tqdm
 
 torch._dynamo.config.automatic_dynamic_shapes = False
@@ -16,15 +16,14 @@ torch._dynamo.config.automatic_dynamic_shapes = False
 torch._dynamo.config.cache_size_limit = 1000
 
+from triton.testing import do_bench
+
 
 def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
     # warmup
     for _ in range(5):
         func(*args, **kwargs)
-    t0 = benchmark.Timer(
-        stmt="func(*args, **kwargs)",
-        globals={"args": args, "kwargs": kwargs, "func": func},
-    )
-    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6
+    return do_bench(lambda: func(*args, **kwargs)) * 1e3
 
 
 @dataclass(frozen=True)
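`triton.testing.do_bench` takes a zero-argument callable, handles its own warmup and repetition, and returns a runtime in milliseconds (mean or median depending on Triton version), hence the `* 1e3` to keep this helper's microseconds contract; the removed `benchmark.Timer` path reported seconds, hence the old `* 1e6`. A minimal usage sketch (standalone, not from this diff):

```python
import torch
import torch.nn.functional as F
from triton.testing import do_bench

q, k, v = (
    torch.randn(4, 8, 2048, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)
ms = do_bench(lambda: F.scaled_dot_product_attention(q, k, v))  # milliseconds
print(f"{ms * 1e3:.1f} us")
```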
@@ -110,8 +109,11 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
         config.dtype,
         device,
     )
-    eager_sdpa = _templated_attention
-    compiled_sdpa = torch.compile(eager_sdpa)
+
+    def eager_sdpa(query, key, value, _):
+        return F.scaled_dot_product_attention(query, key, value)
+
+    compiled_sdpa = torch.compile(_templated_attention)
 
     score_mod = config.score_mod
 
@@ -190,6 +192,9 @@ def print_results(results: List[Experiment]):
 
 
 def generate_score_mods() -> List[Callable]:
+    def noop(score, b, h, m, n):
+        return score
+
     def causal_mask(score, b, h, token_q, token_kv):
         return torch.where(token_q >= token_kv, score, float("-inf"))
@@ -199,14 +204,7 @@ def generate_score_mods() -> List[Callable]:
     def head_bias(score, b, h, m, n):
        return score + 2 * h
 
-    def pathological(score, b, h, m, n):
-        def sin(score, b, h, m, n):
-            return torch.sin(score)
-
-        composed_mod = _compose(*(sin for _ in range(10)))
-        return composed_mod(score, b, h, m, n)
-
-    return [causal_mask, relative_bias, head_bias, pathological]
+    return [noop, causal_mask, relative_bias, head_bias]
 
 
 def generate_experiment_configs() -> List[ExperimentConfig]:
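A score_mod takes the raw attention score for one (batch, head, query-index, key-index) coordinate and returns a modified score, which `_templated_attention` applies elementwise to the QK^T logits before the softmax. A sketch of how the benchmark drives one of these, assuming the positional `(query, key, value, score_mod)` call signature implied by `eager_sdpa`/`compiled_sdpa` above:

```python
import torch
from torch.nn.attention._templated_attention import _templated_attention

def causal_mask(score, b, h, token_q, token_kv):
    # Keep scores on or below the diagonal; mask out future keys.
    return torch.where(token_q >= token_kv, score, float("-inf"))

q, k, v = (
    torch.randn(4, 8, 2048, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)
compiled_sdpa = torch.compile(_templated_attention)
out = compiled_sdpa(q, k, v, causal_mask)
```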
@@ -20,6 +20,7 @@ supported_platform = skipUnless(
 )
 
 Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
+torch.set_float32_matmul_precision("high")
 
 
 def create_attention(score_mod):
@@ -49,18 +50,23 @@ class TestTemplatedSDPA(InductorTestCase):
         q = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
         k = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
         v = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
-        ref_out = sdpa_partial(
+        golden_out = sdpa_partial(
             q.to(torch.float64), k.to(torch.float64), v.to(torch.float64)
         )
+        ref_out = sdpa_partial(q, k, v)
         compiled_out = compiled_sdpa(q, k, v)
 
-        tolerance = Tolerances(atol=2e-2, rtol=2e-2)
-        torch.testing.assert_close(
-            ref_out.to(dtype=torch.float32),
-            compiled_out.to(dtype=torch.float32),
-            atol=tolerance.atol,
-            rtol=tolerance.rtol,
-        )
+        compiled_error = (golden_out - compiled_out).abs().mean()
+        ref_error = (golden_out - ref_out).abs().mean()
+        # Note, it seems like we really are less accurate than the float32
+        # computation, likely due to the online softmax
+        if dtype == torch.float32:
+            fudge_factor = 4.0
+        else:
+            fudge_factor = 1.1
+        if compiled_error > ref_error * fudge_factor:
+            msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than 10%."
+            self.assertTrue(False, msg)
 
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
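The testing pattern above is worth noting: rather than a fixed tolerance against eager, the test computes a float64 "golden" output and only requires the compiled kernel's error to stay within a fudge factor of eager's own error, which absorbs the extra drift from the base-2 online softmax. Distilled into a standalone helper (a sketch, not the test's actual API):

```python
import torch

def assert_error_within_fudge(golden, ref, compiled, fudge_factor):
    # golden: float64 reference; ref: eager output in the test dtype;
    # compiled: kernel output. Both errors are measured against golden.
    compiled_error = (golden - compiled).abs().mean()
    ref_error = (golden - ref).abs().mean()
    assert compiled_error <= ref_error * fudge_factor, (
        f"compiled error {compiled_error} exceeds "
        f"{fudge_factor}x eager error {ref_error}"
    )
```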
@@ -102,6 +108,14 @@ class TestTemplatedSDPA(InductorTestCase):
 
         self.run_test(score_mod, dtype)
 
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes)
+    def test_skip_odd_keys(self, dtype: torch.dtype):
+        def score_mod(score, b, h, q, kv):
+            return torch.where(kv % 2 == 0, score, float("-inf"))
+
+        self.run_test(score_mod, dtype)
+
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
     def test_alibi_causal(self, dtype: torch.dtype):
@@ -1,8 +1,10 @@
 """ Triton Implementation of the Templated SDPA Kernel"""
 import logging
+from typing import Any, List
 
 import torch
-from ..select_algorithm import TritonTemplate
+from ..lowering import lowerings, register_lowering
+from ..select_algorithm import autotune_select_algorithm, TritonTemplate
 
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
@@ -28,6 +30,13 @@ sdpa_template = TritonTemplate(
    # Q: Query, K: Key, V: Value
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
+   # (Modifiable) Config options:
+   # BLOCK_M
+   # BLOCK_N
+   # SCORE_MOD_IS_LINEAR: Is the score modifier linear? If so, we can lift the
+   # change of base out of the loop
+   # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
+   # is not masked out? If so, we can skip an extra safety check
 
    # Define Q Strides
    stride_qz = {{stride("Q", 0)}}
@@ -49,10 +58,8 @@ sdpa_template = TritonTemplate(
    H = {{size("Q", 1)}}
    N_CTX = {{size("Q", 2)}}
 
-   # TODO I think we should do some performance work
-   # to find the optimal calls for perf/accuracy to tl.dot
-   MATMUL_PRECISION = tl.float16
+   qk_scale = 1.0
+   MATMUL_PRECISION = Q.dtype.element_ty
 
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
@@ -89,12 +96,10 @@ sdpa_template = TritonTemplate(
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
-   # scale sm_scale by log_2(e) and use
-   # 2^x instead of exp in the loop because CSE and LICM
-   # don't work as expected with `exp` in the loop
-   # TODO fix me
-   # qk_scale = sm_scale * 1.44269504
 
    q = tl.load(Q_block_ptr)
+   if SCORE_MOD_IS_LINEAR:
+       qk_scale *= 1.44269504
+   q = (q * qk_scale).to(MATMUL_PRECISION)
    # loop over k, v and update accumulator
    lo = 0
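The two `1.44269504` sites in this template are the same change of base applied at different points: when score_mod commutes with positive scaling (what the template calls SCORE_MOD_IS_LINEAR), the log2(e) factor can be folded into `q` once before the K/V loop; otherwise it must be applied to `qk` after the modifier, inside the loop. A minimal sketch of why the lift is valid for such a modifier (causal masking commutes because scaling -inf leaves it -inf):

```python
import torch

LOG2_E = 1.44269504

def causal_mod(score, token_q, token_kv):
    return torch.where(
        token_q >= token_kv, score, torch.full_like(score, float("-inf"))
    )

m = torch.arange(8)[:, None]
n = torch.arange(8)[None, :]
s = torch.randn(8, 8)

# General path: modify first, change base inside the loop.
general = torch.exp2(causal_mod(s, m, n) * LOG2_E)
# "Linear" path: pre-scale the scores (foldable into q), modify after.
lifted = torch.exp2(causal_mod(s * LOG2_E, m, n))
torch.testing.assert_close(general, lifted)
```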
@@ -106,9 +111,8 @@ sdpa_template = TritonTemplate(
        v = tl.load(V_block_ptr)
        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-       qk += tl.dot(q, k.to(MATMUL_PRECISION))
+       qk = tl.dot(q, k.to(MATMUL_PRECISION), acc=qk)
        # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
-
        {{ modification(
            score="qk",
            b="off_hz // H",
@@ -117,6 +121,9 @@ sdpa_template = TritonTemplate(
            n="start_n + offs_n[None, :]",
            out="qk"
        ) | indent_except_first(2) }}
+       # TODO: In the case that score_mod is linear, this can be LICMed
+       if not SCORE_MOD_IS_LINEAR:
+           qk *= 1.44269504
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
        # -- compute scaling constant ---
@@ -124,18 +131,16 @@ sdpa_template = TritonTemplate(
        m_i_new = tl.maximum(m_i, row_max)
        masked_out_rows = (m_i_new == float("-inf"))
 
-       # TODO FIX ME and use 2^x instead of exp
-       # alpha = tl.math.exp2(m_i - m_i_new)
-       # p = tl.math.exp2(qk - m_i_new[:, None])
-       alpha = tl.math.exp(m_i - m_i_new)
-       alpha = tl.where(masked_out_rows, 0, alpha)
-       p = tl.math.exp(qk - m_i_new[:, None])
-       p = tl.where(masked_out_rows[:, None], 0, p)
+       alpha = tl.math.exp2(m_i - m_i_new)
+       p = tl.math.exp2(qk - m_i_new[:, None])
+       if not ROWS_GUARANTEED_SAFE:
+           alpha = tl.where(masked_out_rows, 0, alpha)
+           p = tl.where(masked_out_rows[:, None], 0, p)
 
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
-       acc += tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION))
+       acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc)
 
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
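This loop is the online-softmax recurrence, now entirely in base 2: `m_i` tracks the running row max, `l_i` the running denominator, and `alpha = exp2(m_i - m_i_new)` rescales the previous accumulator whenever a new block raises the max. A self-contained single-row sketch of the same recurrence (names mirror the kernel's; an illustration, not the Triton code):

```python
import torch

LOG2_E = 1.44269504

def online_attention_row(scores, values, block=4):
    """One query row of attention via streaming softmax in base 2."""
    s2 = scores * LOG2_E                # change of base: exp(s) == exp2(s * log2(e))
    m_i = torch.tensor(float("-inf"))   # running max
    l_i = torch.tensor(0.0)             # running softmax denominator
    acc = torch.zeros(values.shape[-1])
    for start in range(0, s2.numel(), block):
        qk = s2[start:start + block]
        m_i_new = torch.maximum(m_i, qk.max())
        alpha = torch.exp2(m_i - m_i_new)   # rescale what was accumulated so far
        p = torch.exp2(qk - m_i_new)        # this block's unnormalized weights
        acc = acc * alpha + p @ values[start:start + block]
        l_i = l_i * alpha + p.sum()
        m_i = m_i_new
    return acc / l_i

scores, values = torch.randn(16), torch.randn(16, 8)
ref = torch.softmax(scores, dim=0) @ values
torch.testing.assert_close(online_attention_row(scores, values), ref)
```

Note that if an entire row stays at -inf, `m_i - m_i_new` is NaN; that is exactly the case the kernel's `masked_out_rows` guard zeroes out, and ROWS_GUARANTEED_SAFE lets callers skip that guard when they know each row has at least one unmasked entry.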
@@ -159,3 +164,125 @@ sdpa_template = TritonTemplate(
    {{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc")}}
 """,
 )
+
+
+@register_lowering(torch.ops.higher_order.templated_attention)
+def templated_attention(*args, **kwargs):
+    from torch._prims_common import make_contiguous_strides_for
+    from ..ir import (
+        ComputedBuffer,
+        FixedLayout,
+        FlexibleLayout,
+        InputBuffer,
+        StorageBox,
+        TensorBox,
+    )
+
+    query, key, value, subgraph = args
+
+    def create_placeholder(name: str, dtype: torch.dtype) -> InputBuffer:
+        return TensorBox.create(
+            InputBuffer(
+                name,
+                FixedLayout(
+                    query.get_device(),
+                    dtype,
+                    [
+                        1,
+                    ],
+                    [
+                        1,
+                    ],
+                ),
+            )
+        )
+
+    scalar_inps = ["score", "b", "h", "m", "n"]
+    env = {}
+    cnt = 0
+    placeholder_inps = [
+        create_placeholder(name, dtype)
+        for name, dtype in [
+            ("score", query.get_dtype()),
+            ("b", torch.int64),
+            ("h", torch.int64),
+            ("m", torch.int64),
+            ("n", torch.int64),
+        ]
+    ]
+    for node in subgraph.graph_module.graph.nodes:
+        # There are two classes of placeholder inputs that we need
+        # to handle differently. For the first n_scalar_inps inputs
+        # we expect that these placeholders were generated by the make_fx call
+        # in the templated Attention HOP. So we need to create a new placeholder
+        # TensorBox for each of these inputs. For the rest of the inputs we
+        # expect that these are lifted inputs that fill up the '*other_buffers'
+        # tuple and already have corresponding TensorBoxes passed in as args.
+        if node.op == "placeholder":
+            is_lifted_input = cnt >= len(scalar_inps)
+            env[node] = args[cnt - 1] if is_lifted_input else placeholder_inps[cnt]
+            cnt += 1
+        elif node.op == "call_function":
+            # For call_function we use the default lowerings and pass in the
+            # already created TensorBoxes as args
+            from torch.utils._pytree import tree_map
+
+            env[node] = lowerings[node.target](
+                *tree_map(lambda x: env[x] if x in env else x, node.args)
+            )
+        elif node.op == "output":
+            # For the output node we need to create a ComputedBuffer
+            # which represents the actual score modification
+
+            output_buffer = env[node.args[0]]
+            assert isinstance(output_buffer.data, StorageBox), (
+                "The output node for the templated attention subgraph must be a StorageBox, but got: ",
+                type(output_buffer),
+            )
+            # Create the ComputedBuffer directly that will be inlined into the modification block
+            subgraph_buffer = ComputedBuffer(
+                name=None,
+                layout=FlexibleLayout(
+                    device=output_buffer.data.get_device(),
+                    dtype=output_buffer.data.get_dtype(),
+                    size=output_buffer.data.get_size(),
+                ),
+                data=output_buffer.data.data,  # type: ignore[arg-type]
+            )
+
+            layout = FixedLayout(
+                output_buffer.get_device(),
+                query.get_dtype(),
+                query.get_size(),
+                make_contiguous_strides_for(query.get_size()),
+            )
+            choices: List[Any] = []
+            configs: List[Any] = []
+            if query.get_dtype() == torch.float32:
+                configs.append((64, 64, 4, 3))
+            configs += [
+                (128, 64, 4, 3),
+                (128, 128, 4, 3),
+                (128, 128, 8, 2),
+                (64, 128, 4, 3),
+            ]
+
+            for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
+                sdpa_template.maybe_append_choice(
+                    choices=choices,
+                    input_nodes=(query, key, value),
+                    layout=layout,
+                    subgraphs=subgraph_buffer,
+                    num_stages=num_stages,
+                    num_warps=num_warps,
+                    BLOCK_M=BLOCK_M,
+                    BLOCK_N=BLOCK_N,
+                    BLOCK_DMODEL=query.get_size()[-1],
+                    # For now, we always assume the "sound" option
+                    SCORE_MOD_IS_LINEAR=False,
+                    ROWS_GUARANTEED_SAFE=False,
+                )
+            return autotune_select_algorithm(
+                "sdpa", choices, [query, key, value], layout
+            )
+    raise ValueError("TemplatedAttention was passed a subgraph with no output node!")
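The node-walking loop above is the standard FX interpreter pattern: thread an `env` dict from nodes to computed values, bind placeholders to inputs, evaluate `call_function` nodes from already-populated entries, and stop at `output`. The same pattern in plain Python, evaluating the graph directly instead of through Inductor's `lowerings` table (a sketch for orientation, not Inductor code):

```python
import torch.fx as fx
from torch.fx.node import map_arg

def interpret(gm: fx.GraphModule, *inputs):
    env = {}
    it = iter(inputs)
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            env[node] = next(it)
        elif node.op == "call_function":
            # Replace Node references in args/kwargs with computed values.
            args = map_arg(node.args, lambda n: env[n])
            kwargs = map_arg(node.kwargs, lambda n: env[n])
            env[node] = node.target(*args, **kwargs)
        elif node.op == "output":
            return map_arg(node.args[0], lambda n: env[n])
```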
@@ -5623,123 +5623,6 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
     return list(map(TensorBox.create, result))
 
 
-@register_lowering(torch.ops.higher_order.templated_attention)
-def templated_attention(*args, **kwargs):
-    from torch._prims_common import make_contiguous_strides_for
-    from .ir import (
-        ComputedBuffer,
-        FixedLayout,
-        FlexibleLayout,
-        InputBuffer,
-        StorageBox,
-        TensorBox,
-    )
-
-    query, key, value, subgraph = args
-
-    def create_placeholder(name: str, dtype: torch.dtype) -> InputBuffer:
-        return TensorBox.create(
-            InputBuffer(
-                name,
-                FixedLayout(
-                    query.get_device(),
-                    dtype,
-                    [
-                        1,
-                    ],
-                    [
-                        1,
-                    ],
-                ),
-            )
-        )
-
-    scalar_inps = ["score", "b", "h", "m", "n"]
-    env = {}
-    cnt = 0
-    placeholder_inps = [
-        create_placeholder(name, dtype)
-        for name, dtype in [
-            ("score", query.get_dtype()),
-            ("b", torch.int64),
-            ("h", torch.int64),
-            ("m", torch.int64),
-            ("n", torch.int64),
-        ]
-    ]
-    for node in subgraph.graph_module.graph.nodes:
-        # There are two classes of placeholder inputs that we need
-        # to handle differently. For the first n_scalar_inps inputs
-        # we expect that these placeholders were generated by the make_fx call
-        # in the templated Attention HOP. So we need to create a new placeholder
-        # TensorBox for each of these inputs. For the rest of the inputs we
-        # expect that these are lifted inputs that fill up the '*other_buffers'
-        # tuple and already have corresponding TensorBoxes passed in as args.
-        if node.op == "placeholder":
-            is_lifted_input = cnt >= len(scalar_inps)
-            env[node] = args[cnt - 1] if is_lifted_input else placeholder_inps[cnt]
-            cnt += 1
-        elif node.op == "call_function":
-            # For call_function we use the default lowerings and pass in the
-            # already created TensorBoxes as args
-            from torch.utils._pytree import tree_map
-
-            env[node] = lowerings[node.target](
-                *tree_map(lambda x: env[x] if x in env else x, node.args)
-            )
-        elif node.op == "output":
-            # For the output node we need to create a ComputedBuffer
-            # which represents the actual score modification
-
-            output_buffer = env[node.args[0]]
-            assert isinstance(output_buffer.data, StorageBox), (
-                "The output node for the templated attention subgraph must be a StorageBox, but got: ",
-                type(output_buffer),
-            )
-            # Create the ComputedBuffer directly that will be inlined into the modification block
-            subgraph_buffer = ComputedBuffer(
-                name=None,
-                layout=FlexibleLayout(
-                    device=output_buffer.data.get_device(),
-                    dtype=output_buffer.data.get_dtype(),
-                    size=output_buffer.data.get_size(),
-                ),
-                data=output_buffer.data.data,  # type: ignore[arg-type]
-            )
-            from .kernel.templated_attention import sdpa_template
-
-            layout = FixedLayout(
-                output_buffer.get_device(),
-                query.get_dtype(),
-                query.get_size(),
-                make_contiguous_strides_for(query.get_size()),
-            )
-            choices: List[Any] = []
-            from .select_algorithm import autotune_select_algorithm
-
-            for BLOCK_M, BLOCK_N, num_warps, num_stages in [
-                (128, 64, 4, 3),
-                (128, 128, 4, 3),
-                (128, 128, 8, 2),
-                (64, 128, 4, 3),
-            ]:
-                sdpa_template.maybe_append_choice(
-                    choices=choices,
-                    input_nodes=(query, key, value),
-                    layout=layout,
-                    subgraphs=subgraph_buffer,
-                    num_stages=num_stages,
-                    num_warps=num_warps,
-                    BLOCK_M=BLOCK_M,
-                    BLOCK_N=BLOCK_N,
-                    BLOCK_DMODEL=query.get_size()[-1],
-                )
-            return autotune_select_algorithm(
-                "sdpa", choices, [query, key, value], layout
-            )
-    raise ValueError("TemplatedAttention was passed a subgraph with no output node!")
-
-
 @register_lowering(torch.ops.prims._sink_tokens.default)
 def _sink_tokens(tokens):
     return None
@@ -206,7 +206,13 @@ class CachingAutotuner(KernelInterface):
                 compiled_binary, launcher = self._precompile_config(
                     c, warm_cache_only_with_cc
                 )
-            except OutOfResources:
+            except OutOfResources as e:
+                if len(self.configs) == 1:
+                    raise RuntimeError(
+                        f"Failed to compile triton config: {c}. "
+                        f"Report a fatal compilation error. "
+                        f"{e}"
+                    )
                 # Skip the config if we run out of resource
                 continue
             self.launchers.append(launcher)
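The shape of this fix: during autotuning, an OutOfResources failure should merely disqualify one candidate config, but when it was the only candidate there is nothing to fall back to, so it becomes a hard error instead of a silently empty launcher list. A generic sketch of the pattern (hypothetical names; triton's OutOfResources stands behind the placeholder exception):

```python
def compile_candidates(configs, compile_fn, resource_error=MemoryError):
    # resource_error is a stand-in for triton's OutOfResources.
    launchers = []
    for cfg in configs:
        try:
            launchers.append(compile_fn(cfg))
        except resource_error as e:
            if len(configs) == 1:
                raise RuntimeError(f"Failed to compile sole config {cfg}: {e}") from e
            continue  # skip this config; other candidates may still fit
    return launchers
```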