[FlexAttention] Add mechanism to get optimal autotune decision
ghstack-source-id: 7dadc7fe8c9436c45fe2e8887d6a0b1b59610487
Pull-Request: https://github.com/pytorch/pytorch/pull/165817
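The change adds an opt-in log of FlexAttention autotune decisions, driven by the TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE environment variable introduced in the diff below. A minimal usage sketch follows; the log path, shapes, and device are illustrative assumptions, not part of the PR.

import json
import os

# Hypothetical log path; Inductor appends a ".json" suffix when writing.
os.environ["TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE"] = "/tmp/flex_configs"

import torch
from torch.nn.attention.flex_attention import flex_attention

q = torch.randn(1, 2, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 2, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 2, 128, 64, device="cuda", dtype=torch.float16)

# Autotuning (and hence logging) only runs under a max-autotune compile mode.
compiled = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")
compiled(q, k, v)

# Each autotuned kernel appends one entry; choices are sorted fastest-first.
with open("/tmp/flex_configs.json") as f:
    for entry in json.load(f):
        dims_key, choices = next(iter(entry.items()))
        print(dims_key, "->", choices[0])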
test/inductor/test_flex_attention.py
@@ -2,8 +2,11 @@
 # flake8: noqa: B950

 import functools
+import json
+import os
 import random
 import string
+import tempfile
 import unittest
 import warnings
 from collections import namedtuple
@@ -6892,6 +6895,120 @@ class TestLearnableBiases(InductorTestCase):
     def test_flex_attention_with_dynamic_max_autotune_graph_partition(self, device):
         self._test_flex_attention_with_dynamic_max_autotune(device)

+    @skip_on_cpu
+    def test_flex_attention_logging(self, device):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            log_file = os.path.join(tmpdir, "flex_attention_configs")
+
+            with patch.dict(
+                os.environ, {"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE": log_file}
+            ):
+                query = torch.randn(
+                    1,
+                    2,
+                    128,
+                    64,
+                    device=device,
+                    dtype=torch.float16,
+                    requires_grad=True,
+                )
+                key = torch.randn(
+                    1,
+                    2,
+                    128,
+                    64,
+                    device=device,
+                    dtype=torch.float16,
+                    requires_grad=True,
+                )
+                value = torch.randn(
+                    1,
+                    2,
+                    128,
+                    64,
+                    device=device,
+                    dtype=torch.float16,
+                    requires_grad=True,
+                )
+
+                def score_mod(score, b, h, q_idx, kv_idx):
+                    return score * 2
+
+                def causal_mask(b, h, q_idx, kv_idx):
+                    return q_idx >= kv_idx
+
+                block_mask = torch.compile(create_block_mask)(
+                    causal_mask, 1, 1, 128, 128, device=device
+                )
+
+                compiled_flex = torch.compile(
+                    flex_attention, mode="max-autotune-no-cudagraphs"
+                )
+
+                out = compiled_flex(
+                    query=query,
+                    key=key,
+                    value=value,
+                    score_mod=score_mod,
+                    block_mask=block_mask,
+                )
+
+                out.sum().backward()
+
+            json_file = log_file + ".json"
+            self.assertTrue(
+                os.path.exists(json_file), f"Log file {json_file} was not created"
+            )
+
+            with open(json_file) as f:
+                log_data = json.load(f)
+
+            self.assertIsInstance(log_data, list)
+            self.assertEqual(len(log_data), 2)
+
+            keys_seen = [next(iter(entry.keys())) for entry in log_data]
+
+            expected_fwd_key = "('forward', 1, 2, 2, 128, 128, 64, 64)"
+            expected_bwd_key = "('backward', 1, 2, 2, 128, 128, 64, 64)"
+
+            self.assertIn(expected_fwd_key, keys_seen)
+            self.assertIn(expected_bwd_key, keys_seen)
+
+            for entry in log_data:
+                self.assertIsInstance(entry, dict)
+                self.assertEqual(len(entry), 1)
+
+                dims_key = next(iter(entry.keys()))
+                choices = entry[dims_key]
+
+                kernel_type = eval(dims_key)[0]
+
+                self.assertIsInstance(choices, list)
+                self.assertGreater(len(choices), 0)
+
+                for i, choice in enumerate(choices):
+                    self.assertIn("type", choice)
+                    self.assertIn("time", choice)
+
+                    if choice["type"] == "triton":
+                        self.assertIn("num_warps", choice)
+                        self.assertIn("num_stages", choice)
+
+                        if kernel_type == "forward":
+                            self.assertIn("BLOCK_M", choice)
+                            self.assertIn("BLOCK_N", choice)
+                            self.assertNotIn("BLOCK_M1", choice)
+                        elif kernel_type == "backward":
+                            self.assertIn("BLOCK_M1", choice)
+                            self.assertIn("BLOCK_N1", choice)
+                            self.assertIn("BLOCK_M2", choice)
+                            self.assertIn("BLOCK_N2", choice)
+                            self.assertNotIn("BLOCK_M", choice)
+                            self.assertNotIn("BLOCK_N", choice)
+
+                    if i > 0:
+                        self.assertLessEqual(choices[0]["time"], choice["time"])
+
     @skip_on_cpu
     def test_inspect_bug(self, device):
         # https://github.com/pytorch/pytorch/issues/139374
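A side note on consuming the log: the test parses the stringified dims tuple with eval, but since the key is just str() of a tuple of literals, external consumers can round-trip it with ast.literal_eval instead. A small sketch, not part of the diff:

import ast

# The dims key round-trips through str()/literal_eval:
kernel_type, B, Hq, Hkv, seq_q, seq_kv, qk_d, v_d = ast.literal_eval(
    "('forward', 1, 2, 2, 128, 128, 64, 64)"
)
assert kernel_type == "forward" and qk_d == 64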
torch/_inductor/select_algorithm.py
@@ -17,6 +17,7 @@ import time
 from collections.abc import Sequence
 from concurrent.futures import as_completed, ThreadPoolExecutor
 from io import StringIO
+from pathlib import Path
 from types import ModuleType
 from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
 from typing_extensions import Self
@@ -2102,6 +2103,7 @@ class TritonTemplate(KernelTemplate):
                 "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
                 "waves_per_eu": kwargs.get("waves_per_eu", 0),
                 "kpack": kwargs.get("kpack", 2),
+                **{k: kwargs[k] for k in _FLEX_ATTENTION_TUNABLE_KEYS if k in kwargs},
             },
             mutated_inputs=mutated_inputs,
             workspace_arg=workspace_arg,
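The added `**{...}` splice forwards only the recognized flex-attention tuning knobs into the template's info dict. A standalone illustration of the pattern, with made-up names:

_KEYS = frozenset(["BLOCK_M", "BLOCK_N"])
kwargs = {"BLOCK_M": 64, "unrelated": 1}
# Only keys in _KEYS survive the splice; "unrelated" is dropped.
config = {"kpack": 2, **{k: kwargs[k] for k in _KEYS if k in kwargs}}
assert config == {"kpack": 2, "BLOCK_M": 64}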
@@ -2395,6 +2397,35 @@ def get_mm_log_filename() -> Optional[str]:
     return mm_file_name


+_FLEX_ATTENTION_TUNABLE_KEYS = frozenset(
+    [
+        "num_warps",
+        "num_stages",
+        "BLOCK_M",
+        "BLOCK_N",
+        "BLOCK_M1",
+        "BLOCK_N1",
+        "BLOCK_M2",
+        "BLOCK_N2",
+        "USE_TMA",
+        "kpack",
+        "matrix_instr_nonkdim",
+        "waves_per_eu",
+    ]
+)
+
+
+@functools.cache
+def get_flex_attention_log_filename() -> Optional[str]:
+    flex_attention_file_name = os.environ.get(
+        "TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE", None
+    )
+    if not flex_attention_file_name:
+        return None
+
+    return str(Path(flex_attention_file_name).with_suffix(".json"))
+
+
 def append_to_log(filename, data):
     lock_file = filename.replace(".json", ".lock")
     lock = FileLock(lock_file)
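Note that Path.with_suffix replaces any existing extension, so differently suffixed values of the environment variable normalize to the same log file. A quick illustration of the standard-library behavior:

from pathlib import Path

assert str(Path("/tmp/flex_configs").with_suffix(".json")) == "/tmp/flex_configs.json"
assert str(Path("/tmp/flex_configs.log").with_suffix(".json")) == "/tmp/flex_configs.json"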
@@ -3500,6 +3531,7 @@ class AlgorithmSelectorCache(PersistentCache):
         prescreening_elapse: Optional[float] = None,
         hint_override: Optional[int] = None,
     ):
+        """Log the autotuning results, currently only handles mm and flex"""
         V.debug.log_autotuning_results(
             name, input_nodes, timings, elapse, precompile_elapse
         )
@@ -3557,6 +3589,26 @@ class AlgorithmSelectorCache(PersistentCache):
                 "num_warps": info["num_warps"],
             }

+        def get_flex_attention_choice_info(choice):
+            if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller):
+                return {"type": "extern", "time": timings[choice]}
+
+            assert isinstance(
+                choice, torch._inductor.select_algorithm.TritonTemplateCaller
+            )
+
+            info = choice.info_dict()
+            result = {
+                "type": "triton",
+                "time": timings[choice],
+            }
+
+            for key in _FLEX_ATTENTION_TUNABLE_KEYS:
+                if key in info:
+                    result[key] = info[key]
+
+            return result
+
         mm_filename = get_mm_log_filename()
         if mm_filename and "mm" in name:
             M, K = input_nodes[-2].get_size()[:2]
@@ -3568,6 +3620,44 @@ class AlgorithmSelectorCache(PersistentCache):

             append_to_log(mm_filename, out_dict)

+        flex_attention_filename = get_flex_attention_log_filename()
+        if flex_attention_filename and "flex_attention" in name:
+            if len(input_nodes) >= 3:
+                query_size = input_nodes[0].get_size()
+                key_size = input_nodes[1].get_size()
+                value_size = input_nodes[2].get_size()
+
+                B = query_size[0]
+                Hq = query_size[1]
+                seq_len_q = query_size[2]
+                qk_head_dim = query_size[3]
+                Hkv = key_size[1]
+                seq_len_kv = key_size[2]
+                v_head_dim = value_size[3]
+
+                kernel_type = "backward" if "backward" in name else "forward"
+                dims_key = str(
+                    (
+                        kernel_type,
+                        B,
+                        Hq,
+                        Hkv,
+                        seq_len_q,
+                        seq_len_kv,
+                        qk_head_dim,
+                        v_head_dim,
+                    )
+                )
+
+                sorted_choices = sorted(timings, key=timings.__getitem__)
+                out_dict = {
+                    dims_key: [
+                        get_flex_attention_choice_info(choice)
+                        for choice in sorted_choices
+                    ]
+                }
+                append_to_log(flex_attention_filename, out_dict)
+
         best_time = timings[best]
         sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
         sys.stderr.write(f"strides: {strides}\n")
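Putting the pieces together, one appended log entry has roughly the shape below. Only the key format and field names follow from the code above; the timings and tile sizes are invented for illustration.

# One logged entry: a dims key mapping to choices sorted by measured time.
{
    "('forward', 1, 2, 2, 128, 128, 64, 64)": [
        # fastest choice first, so choices[0] is the autotune winner
        {"type": "triton", "time": 0.012, "num_warps": 4, "num_stages": 3,
         "BLOCK_M": 64, "BLOCK_N": 64},
        {"type": "triton", "time": 0.015, "num_warps": 8, "num_stages": 2,
         "BLOCK_M": 128, "BLOCK_N": 32},
    ]
}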