mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[FlexAttention] Add mechanism to get optimal autotune decision
ghstack-source-id: 7dadc7fe8c9436c45fe2e8887d6a0b1b59610487 Pull-Request: https://github.com/pytorch/pytorch/pull/165817
This commit is contained in:
@ -2,8 +2,11 @@
|
||||
# flake8: noqa: B950
|
||||
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import tempfile
|
||||
import unittest
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
@ -6892,6 +6895,120 @@ class TestLearnableBiases(InductorTestCase):
|
||||
def test_flex_attention_with_dynamic_max_autotune_graph_partition(self, device):
|
||||
self._test_flex_attention_with_dynamic_max_autotune(device)
|
||||
|
||||
@skip_on_cpu
|
||||
def test_flex_attention_logging(self, device):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
log_file = os.path.join(tmpdir, "flex_attention_configs")
|
||||
|
||||
with patch.dict(
|
||||
os.environ, {"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE": log_file}
|
||||
):
|
||||
query = torch.randn(
|
||||
1,
|
||||
2,
|
||||
128,
|
||||
64,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
requires_grad=True,
|
||||
)
|
||||
key = torch.randn(
|
||||
1,
|
||||
2,
|
||||
128,
|
||||
64,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
requires_grad=True,
|
||||
)
|
||||
value = torch.randn(
|
||||
1,
|
||||
2,
|
||||
128,
|
||||
64,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
requires_grad=True,
|
||||
)
|
||||
|
||||
def score_mod(score, b, h, q_idx, kv_idx):
|
||||
return score * 2
|
||||
|
||||
def causal_mask(b, h, q_idx, kv_idx):
|
||||
return q_idx >= kv_idx
|
||||
|
||||
block_mask = torch.compile(create_block_mask)(
|
||||
causal_mask, 1, 1, 128, 128, device=device
|
||||
)
|
||||
|
||||
compiled_flex = torch.compile(
|
||||
flex_attention, mode="max-autotune-no-cudagraphs"
|
||||
)
|
||||
|
||||
out = compiled_flex(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
score_mod=score_mod,
|
||||
block_mask=block_mask,
|
||||
)
|
||||
|
||||
out.sum().backward()
|
||||
|
||||
json_file = log_file + ".json"
|
||||
self.assertTrue(
|
||||
os.path.exists(json_file), f"Log file {json_file} was not created"
|
||||
)
|
||||
|
||||
with open(json_file) as f:
|
||||
log_data = json.load(f)
|
||||
|
||||
self.assertIsInstance(log_data, list)
|
||||
self.assertEqual(len(log_data), 2)
|
||||
|
||||
keys_seen = [next(iter(entry.keys())) for entry in log_data]
|
||||
|
||||
expected_fwd_key = "('forward', 1, 2, 2, 128, 128, 64, 64)"
|
||||
expected_bwd_key = "('backward', 1, 2, 2, 128, 128, 64, 64)"
|
||||
|
||||
self.assertIn(expected_fwd_key, keys_seen)
|
||||
self.assertIn(expected_bwd_key, keys_seen)
|
||||
|
||||
for entry in log_data:
|
||||
self.assertIsInstance(entry, dict)
|
||||
self.assertEqual(len(entry), 1)
|
||||
|
||||
dims_key = next(iter(entry.keys()))
|
||||
choices = entry[dims_key]
|
||||
|
||||
kernel_type = eval(dims_key)[0]
|
||||
|
||||
self.assertIsInstance(choices, list)
|
||||
self.assertGreater(len(choices), 0)
|
||||
|
||||
for i, choice in enumerate(choices):
|
||||
self.assertIn("type", choice)
|
||||
self.assertIn("time", choice)
|
||||
|
||||
if choice["type"] == "triton":
|
||||
self.assertIn("num_warps", choice)
|
||||
self.assertIn("num_stages", choice)
|
||||
|
||||
if kernel_type == "forward":
|
||||
self.assertIn("BLOCK_M", choice)
|
||||
self.assertIn("BLOCK_N", choice)
|
||||
self.assertNotIn("BLOCK_M1", choice)
|
||||
elif kernel_type == "backward":
|
||||
self.assertIn("BLOCK_M1", choice)
|
||||
self.assertIn("BLOCK_N1", choice)
|
||||
self.assertIn("BLOCK_M2", choice)
|
||||
self.assertIn("BLOCK_N2", choice)
|
||||
self.assertNotIn("BLOCK_M", choice)
|
||||
self.assertNotIn("BLOCK_N", choice)
|
||||
|
||||
if i > 0:
|
||||
self.assertLessEqual(choices[0]["time"], choice["time"])
|
||||
|
||||
@skip_on_cpu
|
||||
def test_inspect_bug(self, device):
|
||||
# https://github.com/pytorch/pytorch/issues/139374
|
||||
|
@ -17,6 +17,7 @@ import time
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import as_completed, ThreadPoolExecutor
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
|
||||
from typing_extensions import Self
|
||||
@ -2102,6 +2103,7 @@ class TritonTemplate(KernelTemplate):
|
||||
"matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
|
||||
"waves_per_eu": kwargs.get("waves_per_eu", 0),
|
||||
"kpack": kwargs.get("kpack", 2),
|
||||
**{k: kwargs[k] for k in _FLEX_ATTENTION_TUNABLE_KEYS if k in kwargs},
|
||||
},
|
||||
mutated_inputs=mutated_inputs,
|
||||
workspace_arg=workspace_arg,
|
||||
@ -2395,6 +2397,35 @@ def get_mm_log_filename() -> Optional[str]:
|
||||
return mm_file_name
|
||||
|
||||
|
||||
_FLEX_ATTENTION_TUNABLE_KEYS = frozenset(
|
||||
[
|
||||
"num_warps",
|
||||
"num_stages",
|
||||
"BLOCK_M",
|
||||
"BLOCK_N",
|
||||
"BLOCK_M1",
|
||||
"BLOCK_N1",
|
||||
"BLOCK_M2",
|
||||
"BLOCK_N2",
|
||||
"USE_TMA",
|
||||
"kpack",
|
||||
"matrix_instr_nonkdim",
|
||||
"waves_per_eu",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_flex_attention_log_filename() -> Optional[str]:
|
||||
flex_attention_file_name = os.environ.get(
|
||||
"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE", None
|
||||
)
|
||||
if not flex_attention_file_name:
|
||||
return None
|
||||
|
||||
return str(Path(flex_attention_file_name).with_suffix(".json"))
|
||||
|
||||
|
||||
def append_to_log(filename, data):
|
||||
lock_file = filename.replace(".json", ".lock")
|
||||
lock = FileLock(lock_file)
|
||||
@ -3500,6 +3531,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
prescreening_elapse: Optional[float] = None,
|
||||
hint_override: Optional[int] = None,
|
||||
):
|
||||
"""Log the autotuning results, currently only handles mm and flex"""
|
||||
V.debug.log_autotuning_results(
|
||||
name, input_nodes, timings, elapse, precompile_elapse
|
||||
)
|
||||
@ -3557,6 +3589,26 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
"num_warps": info["num_warps"],
|
||||
}
|
||||
|
||||
def get_flex_attention_choice_info(choice):
|
||||
if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller):
|
||||
return {"type": "extern", "time": timings[choice]}
|
||||
|
||||
assert isinstance(
|
||||
choice, torch._inductor.select_algorithm.TritonTemplateCaller
|
||||
)
|
||||
|
||||
info = choice.info_dict()
|
||||
result = {
|
||||
"type": "triton",
|
||||
"time": timings[choice],
|
||||
}
|
||||
|
||||
for key in _FLEX_ATTENTION_TUNABLE_KEYS:
|
||||
if key in info:
|
||||
result[key] = info[key]
|
||||
|
||||
return result
|
||||
|
||||
mm_filename = get_mm_log_filename()
|
||||
if mm_filename and "mm" in name:
|
||||
M, K = input_nodes[-2].get_size()[:2]
|
||||
@ -3568,6 +3620,44 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
|
||||
append_to_log(mm_filename, out_dict)
|
||||
|
||||
flex_attention_filename = get_flex_attention_log_filename()
|
||||
if flex_attention_filename and "flex_attention" in name:
|
||||
if len(input_nodes) >= 3:
|
||||
query_size = input_nodes[0].get_size()
|
||||
key_size = input_nodes[1].get_size()
|
||||
value_size = input_nodes[2].get_size()
|
||||
|
||||
B = query_size[0]
|
||||
Hq = query_size[1]
|
||||
seq_len_q = query_size[2]
|
||||
qk_head_dim = query_size[3]
|
||||
Hkv = key_size[1]
|
||||
seq_len_kv = key_size[2]
|
||||
v_head_dim = value_size[3]
|
||||
|
||||
kernel_type = "backward" if "backward" in name else "forward"
|
||||
dims_key = str(
|
||||
(
|
||||
kernel_type,
|
||||
B,
|
||||
Hq,
|
||||
Hkv,
|
||||
seq_len_q,
|
||||
seq_len_kv,
|
||||
qk_head_dim,
|
||||
v_head_dim,
|
||||
)
|
||||
)
|
||||
|
||||
sorted_choices = sorted(timings, key=timings.__getitem__)
|
||||
out_dict = {
|
||||
dims_key: [
|
||||
get_flex_attention_choice_info(choice)
|
||||
for choice in sorted_choices
|
||||
]
|
||||
}
|
||||
append_to_log(flex_attention_filename, out_dict)
|
||||
|
||||
best_time = timings[best]
|
||||
sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
|
||||
sys.stderr.write(f"strides: {strides}\n")
|
||||
|
Reference in New Issue
Block a user