Revert "Improved flexattention bwd perf + added configurations for benchmarks (#129013)"

This reverts commit ff89ebc50a738c734496393dc25313cf197fd0b4.

Reverted https://github.com/pytorch/pytorch/pull/129013 on behalf of https://github.com/huydhn due to: Sorry for reverting your change, but one of the test_torchinductor_opinfo tests started to fail after this commit ff89ebc50a. I am reverting to see if it helps trunk recover ([comment](https://github.com/pytorch/pytorch/pull/129013#issuecomment-2182042422))
Author: PyTorch MergeBot
Date:   2024-06-21 05:46:46 +00:00
Parent: b542825066
Commit: f73b451e78

5 changed files with 43 additions and 110 deletions

View File

@@ -106,9 +106,7 @@ def generate_inputs(
     return query, key, value
 
 
-def run_single_experiment(
-    config: ExperimentConfig, dynamic=False, max_autotune=False
-) -> ExperimentResults:
+def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults:
     device = torch.device("cuda")
     batch_size, num_heads, q_seq_len, head_dim = config.shape
     query, key, value = generate_inputs(
@@ -125,12 +123,7 @@ def run_single_experiment(
     def eager_sdpa(query, key, value, _):
         return F.scaled_dot_product_attention(query, key, value)
 
-    if max_autotune:
-        compiled_sdpa = torch.compile(
-            _flex_attention, dynamic=dynamic, mode="max-autotune-no-cudagraphs"
-        )
-    else:
-        compiled_sdpa = torch.compile(_flex_attention, dynamic=dynamic)
+    compiled_sdpa = torch.compile(_flex_attention, dynamic=dynamic)
 
     score_mod = config.score_mod
@@ -251,7 +244,7 @@ def print_results(results: List[Experiment]):
     print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
 
 
-def generate_score_mods(score_mods: List[str]) -> List[Callable]:
+def generate_score_mods() -> List[Callable]:
     def noop(score, b, h, m, n):
         return score
@@ -264,27 +257,18 @@ def generate_score_mods(score_mods: List[str]) -> List[Callable]:
     def head_bias(score, b, h, m, n):
         return score + 2 * h
 
-    function_dict = {
-        "noop": noop,
-        "causal": causal_mask,
-        "rel": relative_bias,
-        "head_bias": head_bias,
-    }
-    return [function_dict[name] for name in score_mods]
+    return [noop, causal_mask, relative_bias, head_bias]
 
 
-def generate_experiment_configs(
-    calculate_bwd: bool,
-    dtype: torch.dtype,
-    batch_sizes: List[int],
-    num_heads: List[int],
-    seq_lens: List[int],
-    head_dims: List[int],
-    score_mods: List[str],
-) -> List[ExperimentConfig]:
-    q_kv_seq_lens = [(i, i) for i in seq_lens]  # only testing q_len == kv_len
-    dtypes = [dtype]
-    score_mods = generate_score_mods(score_mods)
+def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]:
+    batch_sizes = [2, 8, 16]
+    num_heads = [16]
+    q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
+    head_dims = [64, 128]
+    dtypes = [
+        torch.bfloat16,
+    ]
+    score_mods = generate_score_mods()
     all_configs = []
     for (
         bsz,
@@ -309,23 +293,14 @@ def generate_experiment_configs(
     return all_configs
 
 
-def main(args):
+def main(dynamic: bool, calculate_bwd: bool):
     seed = 123
     np.random.seed(seed)
     torch.manual_seed(seed)
     results = []
-    for config in tqdm(
-        generate_experiment_configs(
-            args.calculate_bwd, args.dtype, args.b, args.nh, args.s, args.d, args.mods
-        )
-    ):
+    for config in tqdm(generate_experiment_configs(calculate_bwd)):
         results.append(
-            Experiment(
-                config,
-                run_single_experiment(
-                    config, dynamic=args.dynamic, max_autotune=args.max_autotune
-                ),
-            )
+            Experiment(config, run_single_experiment(config, dynamic=dynamic))
         )
 
     print_results(results)
@@ -345,29 +320,7 @@ if __name__ == "__main__":
         "--calculate-bwd", action="store_true", help="Calculate backward pass times"
     )
-    parser.add_argument("-dtype", type=str, help="dtype", default="bfloat16")
-    parser.add_argument(
-        "-b", type=int, nargs="+", help="batch sizes", default=[2, 8, 16]
-    )
-    parser.add_argument("-nh", type=int, nargs="+", help="# of heads", default=[16])
-    parser.add_argument(
-        "-s", type=int, nargs="+", help="sequence lengths", default=[512, 1024, 4096]
-    )
-    parser.add_argument("-d", type=int, nargs="+", help="head dims", default=[64, 128])
-    parser.add_argument(
-        "-mods",
-        type=str,
-        nargs="+",
-        help="score mods",
-        default=["noop", "causal", "rel", "head_bias"],
-    )
-    parser.add_argument(
-        "--max-autotune", action="store_true", help="Turn on max-autotune"
-    )
     # Parse arguments
     args = parser.parse_args()
-    args.dtype = getattr(torch, args.dtype)
 
-    main(args)
+    main(args.dynamic, args.calculate_bwd)
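
For reference, the flags removed in this hunk fed the config sweep directly: `generate_experiment_configs` crossed the `-b`, `-nh`, `-s`, `-d`, and `-mods` lists into one experiment per combination. A minimal sketch of that cross-product, using a hypothetical `ExperimentConfig` stand-in rather than the benchmark's actual dataclass:

import itertools
from dataclasses import dataclass
from typing import List, Tuple

import torch


@dataclass(frozen=True)
class ExperimentConfig:  # hypothetical stand-in, not the benchmark's dataclass
    shape: Tuple[int, int, int, int]  # (batch_size, num_heads, seq_len, head_dim)
    score_mod_name: str
    dtype: torch.dtype
    calculate_bwd: bool


def sweep(
    batch_sizes: List[int],
    num_heads: List[int],
    seq_lens: List[int],
    head_dims: List[int],
    score_mods: List[str],
    dtype: torch.dtype,
    calculate_bwd: bool,
) -> List[ExperimentConfig]:
    # Cross every flag value, mirroring the removed -b/-nh/-s/-d/-mods sweep.
    return [
        ExperimentConfig((b, h, s, d), mod, dtype, calculate_bwd)
        for b, h, s, d, mod in itertools.product(
            batch_sizes, num_heads, seq_lens, head_dims, score_mods
        )
    ]


configs = sweep(
    [2, 8, 16], [16], [512, 1024, 4096], [64, 128],
    ["noop", "causal", "rel", "head_bias"], torch.bfloat16, calculate_bwd=False,
)
print(len(configs))  # 3 * 1 * 3 * 2 * 4 = 72 configurations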

View File

@@ -32,7 +32,6 @@ from torch.utils._triton import has_triton
 # Skip tests if Triton is not available
 supported_platform = skipUnless(
     torch.cuda.is_available()
-    and torch.version.hip is None
     and has_triton()
     and torch.cuda.get_device_capability() >= (8, 0),
     "Requires CUDA and Triton",

View File

@@ -277,7 +277,7 @@ def flex_attention_fake_tensor_mode(
         logsumexp = query.new_empty(
             batch_size, num_heads, seq_len_q, dtype=torch.float32
         )
-        return torch.empty_like(query), logsumexp
+        return torch.empty_like(query, memory_format=torch.contiguous_format), logsumexp
 
 
 # ---------------------------- Autograd Implementation ----------------------------
@@ -670,9 +670,9 @@ def flex_attention_backward_fake_tensor_mode(
     *other_buffers: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     with mode:
-        grad_query = torch.empty_like(query)
-        grad_key = torch.empty_like(key)
-        grad_value = torch.empty_like(value)
+        grad_query = torch.empty_like(query, memory_format=torch.contiguous_format)
+        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
+        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)
         return grad_query, grad_key, grad_value
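
The only difference in these two fake-tensor hunks is `memory_format=torch.contiguous_format`: the restored meta functions always report contiguous outputs, while the reverted change let `torch.empty_like` preserve the input's physical layout. A small standalone illustration of that behavior (not PyTorch's meta code):

import torch

# A non-contiguous query layout, e.g. produced by permuting [B, H, M, D].
q = torch.empty(2, 4, 8, 16).permute(0, 2, 1, 3)

# Default empty_like preserves the permuted (dense, non-overlapping) strides...
print(torch.empty_like(q).stride())
# ...while contiguous_format forces row-major strides over the same shape.
print(torch.empty_like(q, memory_format=torch.contiguous_format).stride())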

View File

@@ -17,7 +17,7 @@ from ..ir import (
     Subgraph,
     TensorBox,
 )
-from ..lowering import empty_strided, lowerings, register_lowering
+from ..lowering import empty_strided, full, lowerings, register_lowering
 from ..select_algorithm import autotune_select_algorithm, TritonTemplate
 
 log = logging.getLogger(__name__)
@@ -314,9 +314,6 @@ _h100_default_config = {
     (torch.bfloat16, 64): (128, 64, 4, 3),
     (torch.bfloat16, 128): (64, 32, 4, 3),
     (torch.bfloat16, 256): (64, 32, 4, 3),
-    (torch.float16, 64): (128, 64, 4, 3),
-    (torch.float16, 128): (64, 32, 4, 3),
-    (torch.float16, 256): (64, 32, 4, 3),
 }
 
 _a100_default_config = {
@@ -324,11 +321,8 @@ _a100_default_config = {
     (torch.float32, 128): (128, 32, 4, 3),
     (torch.float32, 256): (64, 16, 4, 3),
     (torch.bfloat16, 64): (128, 64, 4, 3),
-    (torch.bfloat16, 128): (128, 128, 8, 2),
+    (torch.bfloat16, 128): (128, 32, 4, 3),
     (torch.bfloat16, 256): (32, 64, 4, 3),
-    (torch.float16, 64): (128, 64, 4, 3),
-    (torch.float16, 128): (128, 128, 8, 2),
-    (torch.float16, 256): (32, 64, 4, 3),
 }
@@ -362,17 +356,12 @@ def _get_default_config_bwd(query) -> Tuple[int, int, int, int]:
     head_dim = query.get_size()[-1]
     dtype = query.get_dtype()
 
-    if dtype == torch.float32:
-        return (16, 16, 4, 1)
     if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
-        return (32, 128, 4, 3)
-    elif torch.cuda.get_device_capability() >= (8, 0):  # A100
-        if head_dim == 64:
-            return (32, 128, 4, 3)
-        elif head_dim == 128:
-            return (64, 128, 8, 3)
-        else:
-            return (64, 64, 4, 2)
+        if dtype == torch.float32:
+            return (64, 64, 4, 1)
+        return (128, 128, 4, 3)
+    elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0):  # A100
+        return (64, 64, 4, 1)
     else:  # modest hardware or extremely large head_dim
         return (16, 16, 4, 1)
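
The branches restored above key off `torch.cuda.get_device_capability()`: (9, 0) corresponds to Hopper (H100, sm_90) and (8, 0) to Ampere (A100, sm_80), with everything else falling back to the conservative (16, 16, 4, 1) default. A quick way to check which branch a given GPU would take (assumes a CUDA device is present):

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"sm_{major}{minor}")  # e.g. sm_90 on H100, sm_80 on A100
    if (major, minor) >= (9, 0):
        print("H100 branch (for head_dim <= 256)")
    elif (major, minor) >= (8, 0):
        print("A100 branch (for head_dim <= 256)")
    else:
        print("fallback config (16, 16, 4, 1)")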
@@ -400,7 +389,7 @@ def flex_attention(*args, **kwargs):
         query.get_device(),
         query.get_dtype(),
         query.get_size(),
-        query.get_stride(),
+        FlexibleLayout.contiguous_strides(query.get_size()),
     )
     # see NOTE:[TritonTemplates with multiple outputs]
     logsumexp_shape = query.get_size()[:-1]  # [B, H, M]
@@ -572,7 +561,7 @@ flex_attention_backward_template = TritonTemplate(
     curr_n = start_n2
     num_steps = KV_LEN // BLOCK_N2
     for blk_idx in range(num_steps):
-        offs_n2 = curr_n + tl.arange(0, BLOCK_N2)
+        offs_n2= curr_n + tl.arange(0, BLOCK_N2)
         kT = tl.load(kT_ptrs)
         vT = tl.load(vT_ptrs)
         qk = tl.dot(q, kT)
@@ -702,8 +691,8 @@ flex_attention_backward_template = TritonTemplate(
     # Write back dK.
     index_n = offs_n1[:, None]
     index_k = offs_k[None, :]
-    mask = index_n <= KV_LEN
+    # TODO generalize and add proper mask support
+    mask = (index_n != -1) & (index_k != -1)
     {{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
 """,
)
@@ -756,7 +745,7 @@ def flex_attention_backward(*args, **kwargs):
         key.get_device(),
         key.get_dtype(),
         key.get_size(),
-        key.get_stride(),
+        FlexibleLayout.contiguous_strides(key.get_size()),
     )
 
     # Create delta which will is needed for the bwd's kernel
@@ -764,23 +753,20 @@ def flex_attention_backward(*args, **kwargs):
     delta = lowerings[aten.sum](mul_delta, axis=-1)
 
     # see NOTE:[TritonTemplates with multiple outputs]
-    grad_query = empty_strided(
-        query.get_size(), query.get_stride(), dtype=dtype, device=device
-    )
-    grad_value = empty_strided(
-        value.get_size(), value.get_stride(), dtype=dtype, device=device
-    )
+    grad_query = full(
+        query.get_size(), 0.0, dtype=dtype, device=device
+    )  # torch.zeros equivalent
+    grad_query.realize()
+    grad_value = empty_strided(value.get_size(), None, dtype=dtype, device=device)
 
     choices: List[Any] = []
     configs: List[Tuple[int, int, int, int]] = []
     configs.append(_get_default_config_bwd(query))
     if config.max_autotune:
         for BLOCK1 in [32, 64]:
-            for BLOCK2 in [32, 64, 128]:
-                if BLOCK2 % BLOCK1 != 0:
-                    continue
+            for BLOCK2 in [32, 64]:
                 for w in [4, 8]:
-                    for s in [1, 3, 4, 5]:
+                    for s in [1, 3]:
                         configs.append((BLOCK1, BLOCK2, w, s))
 
     for BLOCK1, BLOCK2, num_warps, num_stages in configs:
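
With max-autotune enabled, the restored sweep above benchmarks the default backward config plus a 2 x 2 x 2 x 2 grid of (BLOCK1, BLOCK2, num_warps, num_stages) candidates, versus the larger grid the reverted change searched. A rough standalone illustration of the restored candidate list (hypothetical helper, not the lowering's actual code):

import itertools
from typing import List, Tuple

Config = Tuple[int, int, int, int]  # (BLOCK1, BLOCK2, num_warps, num_stages)


def candidate_bwd_configs(default: Config, max_autotune: bool) -> List[Config]:
    # Default config first; the grid is only added under max-autotune.
    configs = [default]
    if max_autotune:
        configs += list(itertools.product([32, 64], [32, 64], [4, 8], [1, 3]))
    return configs


print(len(candidate_bwd_configs((64, 64, 4, 1), max_autotune=True)))  # 1 + 16 = 17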
@@ -804,9 +790,9 @@ def flex_attention_backward(*args, **kwargs):
             num_stages=num_stages,
             num_warps=num_warps,
             BLOCK_M1=BLOCK1,
-            BLOCK_N1=BLOCK2,
+            BLOCK_N1=BLOCK1,
             BLOCK_M2=BLOCK2,
-            BLOCK_N2=BLOCK1,
+            BLOCK_N2=BLOCK2,
             BLOCK_DMODEL=query.get_size()[-1],
             # For now, we always assume the "sound" option
             SCORE_MOD_IS_LINEAR=False,

View File

@@ -102,8 +102,6 @@ class PartialRender:
         return self.code
 
 
-# This is used to store info needed for lowering each subgraph in triton
-# templates
 SubgraphInfo = namedtuple(
     "SubgraphInfo",
     [
@@ -1578,11 +1576,8 @@ class AlgorithmSelectorCache(PersistentCache):
         for choice in top_k:
             result = timings[choice]
             if result:
-                kernel_info = (
-                    choice.debug_extra if hasattr(choice, "debug_extra") else ""
-                )
                 sys.stderr.write(
-                    f" {choice.name} {result:.4f} ms {best_time / result:.1%} {kernel_info}\n"
+                    f" {choice.name} {result:.4f} ms {best_time / result:.1%}\n"
                 )
             else:
                 sys.stderr.write(
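
The two f-strings above differ only in the trailing `kernel_info` (the choice's `debug_extra`, when present). With made-up values, the removed format renders roughly as follows (illustrative output only, not real benchmark results):

choice_name = "triton_flex_attention_backward_0"  # hypothetical choice name
result, best_time = 0.1234, 0.1111  # hypothetical timings in milliseconds
kernel_info = "BLOCK1=32 BLOCK2=64 num_warps=4 num_stages=3"  # hypothetical debug_extra
print(f" {choice_name} {result:.4f} ms {best_time / result:.1%} {kernel_info}")
# -> " triton_flex_attention_backward_0 0.1234 ms 90.0% BLOCK1=32 BLOCK2=64 num_warps=4 num_stages=3"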