[FlexAttention] Add mechanism to get optimal autotune decision

ghstack-source-id: 7dadc7fe8c9436c45fe2e8887d6a0b1b59610487
Pull-Request: https://github.com/pytorch/pytorch/pull/165817
drisspg
2025-10-19 17:58:34 +00:00
parent a317caf67e
commit 1db4025783
2 changed files with 207 additions and 0 deletions
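
As context for the diff below, here is a minimal usage sketch of the new logging hook. It mirrors the added test; the output path and tensor shapes are illustrative, and only the TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE environment variable and the appended ".json" suffix come from this patch.

import os

# Illustrative path; the helper added in this patch appends a ".json" suffix,
# so the log ends up at /tmp/flex_attention_configs.json.
os.environ["TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE"] = "/tmp/flex_attention_configs"

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


q, k, v = (
    torch.randn(1, 2, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)
block_mask = create_block_mask(causal_mask, 1, 1, 128, 128, device="cuda")

# Max-autotune benchmarking is what produces the logged choices.
compiled_flex = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")
out = compiled_flex(q, k, v, block_mask=block_mask)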


@@ -2,8 +2,11 @@
# flake8: noqa: B950
import functools
import json
import os
import random
import string
import tempfile
import unittest
import warnings
from collections import namedtuple
@@ -6892,6 +6895,120 @@ class TestLearnableBiases(InductorTestCase):
    def test_flex_attention_with_dynamic_max_autotune_graph_partition(self, device):
        self._test_flex_attention_with_dynamic_max_autotune(device)

    @skip_on_cpu
    def test_flex_attention_logging(self, device):
        with tempfile.TemporaryDirectory() as tmpdir:
            log_file = os.path.join(tmpdir, "flex_attention_configs")
            with patch.dict(
                os.environ, {"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE": log_file}
            ):
                query = torch.randn(
                    1,
                    2,
                    128,
                    64,
                    device=device,
                    dtype=torch.float16,
                    requires_grad=True,
                )
                key = torch.randn(
                    1,
                    2,
                    128,
                    64,
                    device=device,
                    dtype=torch.float16,
                    requires_grad=True,
                )
                value = torch.randn(
                    1,
                    2,
                    128,
                    64,
                    device=device,
                    dtype=torch.float16,
                    requires_grad=True,
                )

                def score_mod(score, b, h, q_idx, kv_idx):
                    return score * 2

                def causal_mask(b, h, q_idx, kv_idx):
                    return q_idx >= kv_idx

                block_mask = torch.compile(create_block_mask)(
                    causal_mask, 1, 1, 128, 128, device=device
                )

                compiled_flex = torch.compile(
                    flex_attention, mode="max-autotune-no-cudagraphs"
                )
                out = compiled_flex(
                    query=query,
                    key=key,
                    value=value,
                    score_mod=score_mod,
                    block_mask=block_mask,
                )
                out.sum().backward()

                json_file = log_file + ".json"
                self.assertTrue(
                    os.path.exists(json_file), f"Log file {json_file} was not created"
                )

                with open(json_file) as f:
                    log_data = json.load(f)

                self.assertIsInstance(log_data, list)
                self.assertEqual(len(log_data), 2)

                keys_seen = [next(iter(entry.keys())) for entry in log_data]
                expected_fwd_key = "('forward', 1, 2, 2, 128, 128, 64, 64)"
                expected_bwd_key = "('backward', 1, 2, 2, 128, 128, 64, 64)"
                self.assertIn(expected_fwd_key, keys_seen)
                self.assertIn(expected_bwd_key, keys_seen)

                for entry in log_data:
                    self.assertIsInstance(entry, dict)
                    self.assertEqual(len(entry), 1)

                    dims_key = next(iter(entry.keys()))
                    choices = entry[dims_key]
                    kernel_type = eval(dims_key)[0]

                    self.assertIsInstance(choices, list)
                    self.assertGreater(len(choices), 0)

                    for i, choice in enumerate(choices):
                        self.assertIn("type", choice)
                        self.assertIn("time", choice)

                        if choice["type"] == "triton":
                            self.assertIn("num_warps", choice)
                            self.assertIn("num_stages", choice)

                            if kernel_type == "forward":
                                self.assertIn("BLOCK_M", choice)
                                self.assertIn("BLOCK_N", choice)
                                self.assertNotIn("BLOCK_M1", choice)
                            elif kernel_type == "backward":
                                self.assertIn("BLOCK_M1", choice)
                                self.assertIn("BLOCK_N1", choice)
                                self.assertIn("BLOCK_M2", choice)
                                self.assertIn("BLOCK_N2", choice)
                                self.assertNotIn("BLOCK_M", choice)
                                self.assertNotIn("BLOCK_N", choice)

                        if i > 0:
                            self.assertLessEqual(choices[0]["time"], choice["time"])

    @skip_on_cpu
    def test_inspect_bug(self, device):
        # https://github.com/pytorch/pytorch/issues/139374


@@ -17,6 +17,7 @@ import time
from collections.abc import Sequence
from concurrent.futures import as_completed, ThreadPoolExecutor
from io import StringIO
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
from typing_extensions import Self
@@ -2102,6 +2103,7 @@ class TritonTemplate(KernelTemplate):
                "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
                "waves_per_eu": kwargs.get("waves_per_eu", 0),
                "kpack": kwargs.get("kpack", 2),
                **{k: kwargs[k] for k in _FLEX_ATTENTION_TUNABLE_KEYS if k in kwargs},
            },
            mutated_inputs=mutated_inputs,
            workspace_arg=workspace_arg,
@@ -2395,6 +2397,35 @@ def get_mm_log_filename() -> Optional[str]:
    return mm_file_name


_FLEX_ATTENTION_TUNABLE_KEYS = frozenset(
    [
        "num_warps",
        "num_stages",
        "BLOCK_M",
        "BLOCK_N",
        "BLOCK_M1",
        "BLOCK_N1",
        "BLOCK_M2",
        "BLOCK_N2",
        "USE_TMA",
        "kpack",
        "matrix_instr_nonkdim",
        "waves_per_eu",
    ]
)


@functools.cache
def get_flex_attention_log_filename() -> Optional[str]:
    flex_attention_file_name = os.environ.get(
        "TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE", None
    )
    if not flex_attention_file_name:
        return None
    return str(Path(flex_attention_file_name).with_suffix(".json"))


def append_to_log(filename, data):
    lock_file = filename.replace(".json", ".lock")
    lock = FileLock(lock_file)
@@ -3500,6 +3531,7 @@ class AlgorithmSelectorCache(PersistentCache):
        prescreening_elapse: Optional[float] = None,
        hint_override: Optional[int] = None,
    ):
        """Log the autotuning results, currently only handles mm and flex"""
        V.debug.log_autotuning_results(
            name, input_nodes, timings, elapse, precompile_elapse
        )
@@ -3557,6 +3589,26 @@ class AlgorithmSelectorCache(PersistentCache):
                "num_warps": info["num_warps"],
            }

        def get_flex_attention_choice_info(choice):
            if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller):
                return {"type": "extern", "time": timings[choice]}

            assert isinstance(
                choice, torch._inductor.select_algorithm.TritonTemplateCaller
            )

            info = choice.info_dict()
            result = {
                "type": "triton",
                "time": timings[choice],
            }
            for key in _FLEX_ATTENTION_TUNABLE_KEYS:
                if key in info:
                    result[key] = info[key]
            return result

        mm_filename = get_mm_log_filename()
        if mm_filename and "mm" in name:
            M, K = input_nodes[-2].get_size()[:2]
@@ -3568,6 +3620,44 @@ class AlgorithmSelectorCache(PersistentCache):
            append_to_log(mm_filename, out_dict)

        flex_attention_filename = get_flex_attention_log_filename()
        if flex_attention_filename and "flex_attention" in name:
            if len(input_nodes) >= 3:
                query_size = input_nodes[0].get_size()
                key_size = input_nodes[1].get_size()
                value_size = input_nodes[2].get_size()

                B = query_size[0]
                Hq = query_size[1]
                seq_len_q = query_size[2]
                qk_head_dim = query_size[3]
                Hkv = key_size[1]
                seq_len_kv = key_size[2]
                v_head_dim = value_size[3]

                kernel_type = "backward" if "backward" in name else "forward"

                dims_key = str(
                    (
                        kernel_type,
                        B,
                        Hq,
                        Hkv,
                        seq_len_q,
                        seq_len_kv,
                        qk_head_dim,
                        v_head_dim,
                    )
                )

                sorted_choices = sorted(timings, key=timings.__getitem__)
                out_dict = {
                    dims_key: [
                        get_flex_attention_choice_info(choice)
                        for choice in sorted_choices
                    ]
                }
                append_to_log(flex_attention_filename, out_dict)

        best_time = timings[best]
        sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
        sys.stderr.write(f"strides: {strides}\n")