Compare commits


1 Commit

SHA1: cf1bb45476
Message: add option for specifying backend
    ghstack-source-id: e9620f5fe3577ca23f0db04b0126ca8687424c1d
    Pull Request resolved: https://github.com/pytorch/pytorch/pull/133099
Date: 2024-08-09 11:04:32 -07:00


@@ -2,6 +2,7 @@ import argparse
import csv
import itertools
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, List, Optional, Tuple

@@ -12,6 +13,7 @@ from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention.flex_attention import (
    _create_empty_block_mask,
    create_block_mask,
@@ -42,6 +44,7 @@ class ExperimentConfig:
    dtype: torch.dtype
    calculate_bwd_time: bool
    cal_bandwidth: bool
    backend: str

    def __post_init__(self):
        assert (
@@ -123,78 +126,104 @@ def generate_inputs(
    return query, key, value


@sdpa_kernel(SDPBackend.CUDNN_ATTENTION)
def run_single_experiment(
    config: ExperimentConfig,
    dynamic=False,
    max_autotune=False,
) -> ExperimentResults:
    device = torch.device("cuda")
    batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
    query, key, value = generate_inputs(
        batch_size,
        q_heads,
        q_seq_len,
        kv_heads,
        kv_seq_len,
        head_dim,
        config.dtype,
        device,
        requires_grad=config.calculate_bwd_time,
    )

    kwargs = {}
    if get_func_name(config.mask_mod) == "causal":
        kwargs["is_causal"] = True

    def eager_sdpa(query, key, value, _):
        return F.scaled_dot_product_attention(query, key, value, **kwargs)

    if max_autotune:
        compiled_sdpa = torch.compile(
            flex_attention, dynamic=dynamic, mode="max-autotune-no-cudagraphs"
        )
    else:
        compiled_sdpa = torch.compile(flex_attention, dynamic=dynamic)

    score_mod = config.score_mod
    mask_mod = config.mask_mod

    if mask_mod:
        block_mask = create_block_mask(
            mask_mod, 1, 1, q_seq_len * (q_heads // kv_heads), kv_seq_len, query.device
        )
    else:
        block_mask = _create_empty_block_mask(query, key)

    forward_eager_time = benchmark_torch_function_in_microseconds(
        eager_sdpa, query, key, value, score_mod
    )
    forward_compiled_time = benchmark_torch_function_in_microseconds(
        compiled_sdpa, query, key, value, score_mod, block_mask
    )

    if config.calculate_bwd_time:
        out_eager = eager_sdpa(query, key, value, score_mod)
        dOut = torch.randn_like(out_eager)
        backward_eager_time = benchmark_torch_function_in_microseconds(
            out_eager.backward, dOut, retain_graph=True
    backend_context = get_backend_context(config.backend)
    with backend_context:
        device = torch.device("cuda")
        batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
        query, key, value = generate_inputs(
            batch_size,
            q_heads,
            q_seq_len,
            kv_heads,
            kv_seq_len,
            head_dim,
            config.dtype,
            device,
            requires_grad=config.calculate_bwd_time,
        )
        out_compile = compiled_sdpa(query, key, value, score_mod, block_mask)
        dOut = torch.randn_like(out_eager)
        backward_compile_time = benchmark_torch_function_in_microseconds(
            out_compile.backward, dOut, retain_graph=True

        kwargs = {}
        if get_func_name(config.mask_mod) == "causal":
            kwargs["is_causal"] = True

        if config.backend != "fav3":

            def eager_sdpa(query, key, value, _):
                return F.scaled_dot_product_attention(query, key, value, **kwargs)

        else:
            try:
                from flash_attn_interface import flash_attn_func
            except ImportError:
                print(
                    "Flash attention 3 is not installed. Please install it to run with fav3 backend."
                )
                raise
            kwargs["causal"] = kwargs.pop("is_causal", False)

            def eager_sdpa(query, key, value, _):
                q = query.transpose(1, 2)
                k = key.transpose(1, 2)
                v = value.transpose(1, 2)
                return (flash_attn_func(q, k, v, **kwargs)[0]).transpose(1, 2)

        if max_autotune:
            compiled_sdpa = torch.compile(
                flex_attention, dynamic=dynamic, mode="max-autotune-no-cudagraphs"
            )
        else:
            compiled_sdpa = torch.compile(flex_attention, dynamic=dynamic)

        score_mod = config.score_mod
        mask_mod = config.mask_mod

        if mask_mod:
            block_mask = create_block_mask(
                mask_mod,
                1,
                1,
                q_seq_len * (q_heads // kv_heads),
                kv_seq_len,
                query.device,
            )
        else:
            block_mask = _create_empty_block_mask(query, key)

        forward_eager_time = benchmark_torch_function_in_microseconds(
            eager_sdpa, query, key, value, score_mod
        )
        forward_compiled_time = benchmark_torch_function_in_microseconds(
            compiled_sdpa, query, key, value, score_mod, block_mask
        )
        return ExperimentResults(
            fwd_times=Times(forward_eager_time, forward_compiled_time),
            bwd_times=Times(backward_eager_time, backward_compile_time),
        )
    else:
        return ExperimentResults(
            fwd_times=Times(forward_eager_time, forward_compiled_time),
            bwd_times=None,
        )

        if config.calculate_bwd_time:
            out_eager = eager_sdpa(query, key, value, score_mod)
            dOut = torch.randn_like(out_eager)
            backward_eager_time = benchmark_torch_function_in_microseconds(
                out_eager.backward, dOut, retain_graph=True
            )

            out_compile = compiled_sdpa(query, key, value, score_mod, block_mask)
            dOut = torch.randn_like(out_eager)
            backward_compile_time = benchmark_torch_function_in_microseconds(
                out_compile.backward, dOut, retain_graph=True
            )
            return ExperimentResults(
                fwd_times=Times(forward_eager_time, forward_compiled_time),
                bwd_times=Times(backward_eager_time, backward_compile_time),
            )
        else:
            return ExperimentResults(
                fwd_times=Times(forward_eager_time, forward_compiled_time),
                bwd_times=None,
            )

def calculate_speedup(results: ExperimentResults, type: str) -> float:
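
Side note on the fav3 path above: flash_attn_func from FlashAttention-3's flash_attn_interface takes tensors in (batch, seq, heads, head_dim) layout and returns a tuple whose first element is the attention output, while F.scaled_dot_product_attention uses (batch, heads, seq, head_dim); the transposes in the new eager_sdpa wrapper convert between the two. A minimal standalone parity check of that conversion (a sketch, not part of the PR; it assumes flash_attn_interface is installed, a CUDA device is available, and the bf16 tolerances are a rough guess):

    import torch
    import torch.nn.functional as F
    from flash_attn_interface import flash_attn_func

    # SDPA layout: (batch, heads, seq, head_dim); FA3 layout: (batch, seq, heads, head_dim)
    q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))

    out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out_fa3 = flash_attn_func(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=True
    )[0].transpose(1, 2)

    torch.testing.assert_close(out_sdpa, out_fa3, atol=2e-2, rtol=2e-2)
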
@@ -377,6 +406,37 @@ def generate_score_mods(score_mods: List[str]) -> List[Callable | None]:
    return [function_dict[name] for name in score_mods]

def get_backend_context(backend: str):
    """
    Returns a context manager for the specified backend.

    Args:
        backend (str): The name of the backend to use.
            Valid options are 'default', 'math', 'efficient', 'cudnn', 'fav2', and 'fav3'.

    Returns:
        A context manager for the specified backend.

    Raises:
        ValueError: If an invalid backend is specified.
    """
    backends = {
        "default": nullcontext(),
        "fav2": sdpa_kernel(SDPBackend.FLASH_ATTENTION),
        "cudnn": sdpa_kernel(SDPBackend.CUDNN_ATTENTION),
        "math": sdpa_kernel(SDPBackend.MATH),
        "efficient": sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION),
        "fav3": nullcontext(),
    }
    if backend not in backends:
        raise ValueError(
            f"Unknown backend: {backend}. Valid options are: {', '.join(backends.keys())}"
        )
    return backends[backend]

def generate_mask_mods(score_mods: List[str]) -> List[Callable | None]:
    def noop(b, h, m, n):
        return True

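For reference, a minimal sketch of how the context manager returned by get_backend_context above scopes an attention call to a single kernel; the shapes and dtype here are illustrative, not taken from the benchmark, and a CUDA device is assumed:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import sdpa_kernel, SDPBackend

    # get_backend_context("efficient") returns sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION);
    # inside the with-block, SDPA is restricted to the memory-efficient kernel.
    q = k = v = torch.randn(1, 4, 256, 64, device="cuda", dtype=torch.float16)
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v)

For "default" and "fav3" the helper returns nullcontext(), so no kernel restriction is applied; the fav3 case instead swaps the eager attention function itself, as shown in run_single_experiment.
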
@@ -434,6 +494,7 @@ def generate_flash_configs(
    decoding: bool,
    kv_cache_size: List[int],
    cal_bandwidth: bool,
    backend: str,
) -> List[ExperimentConfig]:
    assert not (calculate_bwd and decoding), "Decoding does not support backward"
@@ -488,6 +549,7 @@ def generate_flash_configs(
                dtype=dtype,
                calculate_bwd_time=calculate_bwd,
                cal_bandwidth=cal_bandwidth,
                backend=backend,
            )
        )
@@ -505,6 +567,7 @@ def generate_experiment_configs(
    decoding: bool,
    kv_cache_size: List[int],
    cal_bandwidth: bool,
    backend: str,
) -> List[ExperimentConfig]:
    assert not (calculate_bwd and decoding), "Decoding does not support backward"
@@ -555,6 +618,7 @@ def generate_experiment_configs(
                dtype=dtype,
                calculate_bwd_time=calculate_bwd,
                cal_bandwidth=cal_bandwidth,
                backend=backend,
            )
        )
@ -578,6 +642,7 @@ def main(args):
args.decoding,
args.kv_cache_size,
args.cal_bandwidth,
args.backend,
)
):
results.append(
@@ -668,6 +733,13 @@ Ignores -b batch size and calculate batch size from kv_cache size instead when s
        help="Path to save the results JSON file (optional)",
        default=None,
    )
    parser.add_argument(
        "--backend",
        type=str,
        choices=["default", "math", "efficient", "cudnn", "fav2", "fav3"],
        default="default",
        help="Backend to use for attention computation",
    )

    # Parse arguments
    args = parser.parse_args()
    args.dtype = getattr(torch, args.dtype)
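
Usage note: with the new flag, a specific kernel can be selected from the command line, for example (the benchmark's file path is not shown in this compare view; the flex attention score_mod benchmark script is assumed here):

    python benchmarks/transformer/score_mod.py --backend cudnn

Since --backend defaults to "default", which get_backend_context maps to nullcontext(), existing invocations keep their previous behavior.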