[FlexAttention] Add mechanism to get optimal autotune decision

ghstack-source-id: 7dadc7fe8c9436c45fe2e8887d6a0b1b59610487
Pull-Request: https://github.com/pytorch/pytorch/pull/165817
This commit is contained in:
drisspg
2025-10-19 17:58:34 +00:00
parent a317caf67e
commit 1db4025783
2 changed files with 207 additions and 0 deletions

View File

@ -2,8 +2,11 @@
# flake8: noqa: B950
import functools
import json
import os
import random
import string
import tempfile
import unittest
import warnings
from collections import namedtuple
@ -6892,6 +6895,120 @@ class TestLearnableBiases(InductorTestCase):
def test_flex_attention_with_dynamic_max_autotune_graph_partition(self, device):
self._test_flex_attention_with_dynamic_max_autotune(device)
@skip_on_cpu
def test_flex_attention_logging(self, device):
with tempfile.TemporaryDirectory() as tmpdir:
log_file = os.path.join(tmpdir, "flex_attention_configs")
with patch.dict(
os.environ, {"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE": log_file}
):
query = torch.randn(
1,
2,
128,
64,
device=device,
dtype=torch.float16,
requires_grad=True,
)
key = torch.randn(
1,
2,
128,
64,
device=device,
dtype=torch.float16,
requires_grad=True,
)
value = torch.randn(
1,
2,
128,
64,
device=device,
dtype=torch.float16,
requires_grad=True,
)
def score_mod(score, b, h, q_idx, kv_idx):
return score * 2
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = torch.compile(create_block_mask)(
causal_mask, 1, 1, 128, 128, device=device
)
compiled_flex = torch.compile(
flex_attention, mode="max-autotune-no-cudagraphs"
)
out = compiled_flex(
query=query,
key=key,
value=value,
score_mod=score_mod,
block_mask=block_mask,
)
out.sum().backward()
json_file = log_file + ".json"
self.assertTrue(
os.path.exists(json_file), f"Log file {json_file} was not created"
)
with open(json_file) as f:
log_data = json.load(f)
self.assertIsInstance(log_data, list)
self.assertEqual(len(log_data), 2)
keys_seen = [next(iter(entry.keys())) for entry in log_data]
expected_fwd_key = "('forward', 1, 2, 2, 128, 128, 64, 64)"
expected_bwd_key = "('backward', 1, 2, 2, 128, 128, 64, 64)"
self.assertIn(expected_fwd_key, keys_seen)
self.assertIn(expected_bwd_key, keys_seen)
for entry in log_data:
self.assertIsInstance(entry, dict)
self.assertEqual(len(entry), 1)
dims_key = next(iter(entry.keys()))
choices = entry[dims_key]
kernel_type = eval(dims_key)[0]
self.assertIsInstance(choices, list)
self.assertGreater(len(choices), 0)
for i, choice in enumerate(choices):
self.assertIn("type", choice)
self.assertIn("time", choice)
if choice["type"] == "triton":
self.assertIn("num_warps", choice)
self.assertIn("num_stages", choice)
if kernel_type == "forward":
self.assertIn("BLOCK_M", choice)
self.assertIn("BLOCK_N", choice)
self.assertNotIn("BLOCK_M1", choice)
elif kernel_type == "backward":
self.assertIn("BLOCK_M1", choice)
self.assertIn("BLOCK_N1", choice)
self.assertIn("BLOCK_M2", choice)
self.assertIn("BLOCK_N2", choice)
self.assertNotIn("BLOCK_M", choice)
self.assertNotIn("BLOCK_N", choice)
if i > 0:
self.assertLessEqual(choices[0]["time"], choice["time"])
@skip_on_cpu
def test_inspect_bug(self, device):
# https://github.com/pytorch/pytorch/issues/139374

View File

@ -17,6 +17,7 @@ import time
from collections.abc import Sequence
from concurrent.futures import as_completed, ThreadPoolExecutor
from io import StringIO
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
from typing_extensions import Self
@ -2102,6 +2103,7 @@ class TritonTemplate(KernelTemplate):
"matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
"waves_per_eu": kwargs.get("waves_per_eu", 0),
"kpack": kwargs.get("kpack", 2),
**{k: kwargs[k] for k in _FLEX_ATTENTION_TUNABLE_KEYS if k in kwargs},
},
mutated_inputs=mutated_inputs,
workspace_arg=workspace_arg,
@ -2395,6 +2397,35 @@ def get_mm_log_filename() -> Optional[str]:
return mm_file_name
_FLEX_ATTENTION_TUNABLE_KEYS = frozenset(
[
"num_warps",
"num_stages",
"BLOCK_M",
"BLOCK_N",
"BLOCK_M1",
"BLOCK_N1",
"BLOCK_M2",
"BLOCK_N2",
"USE_TMA",
"kpack",
"matrix_instr_nonkdim",
"waves_per_eu",
]
)
@functools.cache
def get_flex_attention_log_filename() -> Optional[str]:
flex_attention_file_name = os.environ.get(
"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE", None
)
if not flex_attention_file_name:
return None
return str(Path(flex_attention_file_name).with_suffix(".json"))
def append_to_log(filename, data):
lock_file = filename.replace(".json", ".lock")
lock = FileLock(lock_file)
@ -3500,6 +3531,7 @@ class AlgorithmSelectorCache(PersistentCache):
prescreening_elapse: Optional[float] = None,
hint_override: Optional[int] = None,
):
"""Log the autotuning results, currently only handles mm and flex"""
V.debug.log_autotuning_results(
name, input_nodes, timings, elapse, precompile_elapse
)
@ -3557,6 +3589,26 @@ class AlgorithmSelectorCache(PersistentCache):
"num_warps": info["num_warps"],
}
def get_flex_attention_choice_info(choice):
if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller):
return {"type": "extern", "time": timings[choice]}
assert isinstance(
choice, torch._inductor.select_algorithm.TritonTemplateCaller
)
info = choice.info_dict()
result = {
"type": "triton",
"time": timings[choice],
}
for key in _FLEX_ATTENTION_TUNABLE_KEYS:
if key in info:
result[key] = info[key]
return result
mm_filename = get_mm_log_filename()
if mm_filename and "mm" in name:
M, K = input_nodes[-2].get_size()[:2]
@ -3568,6 +3620,44 @@ class AlgorithmSelectorCache(PersistentCache):
append_to_log(mm_filename, out_dict)
flex_attention_filename = get_flex_attention_log_filename()
if flex_attention_filename and "flex_attention" in name:
if len(input_nodes) >= 3:
query_size = input_nodes[0].get_size()
key_size = input_nodes[1].get_size()
value_size = input_nodes[2].get_size()
B = query_size[0]
Hq = query_size[1]
seq_len_q = query_size[2]
qk_head_dim = query_size[3]
Hkv = key_size[1]
seq_len_kv = key_size[2]
v_head_dim = value_size[3]
kernel_type = "backward" if "backward" in name else "forward"
dims_key = str(
(
kernel_type,
B,
Hq,
Hkv,
seq_len_q,
seq_len_kv,
qk_head_dim,
v_head_dim,
)
)
sorted_choices = sorted(timings, key=timings.__getitem__)
out_dict = {
dims_key: [
get_flex_attention_choice_info(choice)
for choice in sorted_choices
]
}
append_to_log(flex_attention_filename, out_dict)
best_time = timings[best]
sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
sys.stderr.write(f"strides: {strides}\n")