diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index 1e631b4af389..3ce3143a29fe 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -2,8 +2,11 @@
 # flake8: noqa: B950
 
 import functools
+import json
+import os
 import random
 import string
+import tempfile
 import unittest
 import warnings
 from collections import namedtuple
@@ -6892,6 +6895,120 @@ class TestLearnableBiases(InductorTestCase):
     def test_flex_attention_with_dynamic_max_autotune_graph_partition(self, device):
         self._test_flex_attention_with_dynamic_max_autotune(device)
 
+    @skip_on_cpu
+    def test_flex_attention_logging(self, device):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            log_file = os.path.join(tmpdir, "flex_attention_configs")
+
+            with patch.dict(
+                os.environ, {"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE": log_file}
+            ):
+                query = torch.randn(
+                    1,
+                    2,
+                    128,
+                    64,
+                    device=device,
+                    dtype=torch.float16,
+                    requires_grad=True,
+                )
+                key = torch.randn(
+                    1,
+                    2,
+                    128,
+                    64,
+                    device=device,
+                    dtype=torch.float16,
+                    requires_grad=True,
+                )
+                value = torch.randn(
+                    1,
+                    2,
+                    128,
+                    64,
+                    device=device,
+                    dtype=torch.float16,
+                    requires_grad=True,
+                )
+
+                def score_mod(score, b, h, q_idx, kv_idx):
+                    return score * 2
+
+                def causal_mask(b, h, q_idx, kv_idx):
+                    return q_idx >= kv_idx
+
+                block_mask = torch.compile(create_block_mask)(
+                    causal_mask, 1, 1, 128, 128, device=device
+                )
+
+                compiled_flex = torch.compile(
+                    flex_attention, mode="max-autotune-no-cudagraphs"
+                )
+
+                out = compiled_flex(
+                    query=query,
+                    key=key,
+                    value=value,
+                    score_mod=score_mod,
+                    block_mask=block_mask,
+                )
+
+                out.sum().backward()
+
+            json_file = log_file + ".json"
+            self.assertTrue(
+                os.path.exists(json_file), f"Log file {json_file} was not created"
+            )
+
+            with open(json_file) as f:
+                log_data = json.load(f)
+
+            self.assertIsInstance(log_data, list)
+            self.assertEqual(len(log_data), 2)
+
+            keys_seen = [next(iter(entry.keys())) for entry in log_data]
+
+            expected_fwd_key = "('forward', 1, 2, 2, 128, 128, 64, 64)"
+            expected_bwd_key = "('backward', 1, 2, 2, 128, 128, 64, 64)"
+
+            self.assertIn(expected_fwd_key, keys_seen)
+            self.assertIn(expected_bwd_key, keys_seen)
+
+            for entry in log_data:
+                self.assertIsInstance(entry, dict)
+                self.assertEqual(len(entry), 1)
+
+                dims_key = next(iter(entry.keys()))
+                choices = entry[dims_key]
+
+                kernel_type = eval(dims_key)[0]
+
+                self.assertIsInstance(choices, list)
+                self.assertGreater(len(choices), 0)
+
+                for i, choice in enumerate(choices):
+                    self.assertIn("type", choice)
+                    self.assertIn("time", choice)
+
+                    if choice["type"] == "triton":
+                        self.assertIn("num_warps", choice)
+                        self.assertIn("num_stages", choice)
+
+                        if kernel_type == "forward":
+                            self.assertIn("BLOCK_M", choice)
+                            self.assertIn("BLOCK_N", choice)
+                            self.assertNotIn("BLOCK_M1", choice)
+                        elif kernel_type == "backward":
+                            self.assertIn("BLOCK_M1", choice)
+                            self.assertIn("BLOCK_N1", choice)
+                            self.assertIn("BLOCK_M2", choice)
+                            self.assertIn("BLOCK_N2", choice)
+                            self.assertNotIn("BLOCK_M", choice)
+                            self.assertNotIn("BLOCK_N", choice)
+
+                    if i > 0:
+                        self.assertLessEqual(choices[0]["time"], choice["time"])
+
     @skip_on_cpu
     def test_inspect_bug(self, device):
         # https://github.com/pytorch/pytorch/issues/139374
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index b0e81444ad84..ca58c391da6d 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -17,6 +17,7 @@ import time
 from collections.abc import Sequence
 from concurrent.futures import as_completed, ThreadPoolExecutor
 from io import StringIO
+from pathlib import Path
 from types import ModuleType
 from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
 from typing_extensions import Self
@@ -2102,6 +2103,7 @@ class TritonTemplate(KernelTemplate):
                 "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
                 "waves_per_eu": kwargs.get("waves_per_eu", 0),
                 "kpack": kwargs.get("kpack", 2),
+                **{k: kwargs[k] for k in _FLEX_ATTENTION_TUNABLE_KEYS if k in kwargs},
             },
             mutated_inputs=mutated_inputs,
             workspace_arg=workspace_arg,
@@ -2395,6 +2397,35 @@ def get_mm_log_filename() -> Optional[str]:
     return mm_file_name
 
 
+_FLEX_ATTENTION_TUNABLE_KEYS = frozenset(
+    [
+        "num_warps",
+        "num_stages",
+        "BLOCK_M",
+        "BLOCK_N",
+        "BLOCK_M1",
+        "BLOCK_N1",
+        "BLOCK_M2",
+        "BLOCK_N2",
+        "USE_TMA",
+        "kpack",
+        "matrix_instr_nonkdim",
+        "waves_per_eu",
+    ]
+)
+
+
+@functools.cache
+def get_flex_attention_log_filename() -> Optional[str]:
+    flex_attention_file_name = os.environ.get(
+        "TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE", None
+    )
+    if not flex_attention_file_name:
+        return None
+
+    return str(Path(flex_attention_file_name).with_suffix(".json"))
+
+
 def append_to_log(filename, data):
     lock_file = filename.replace(".json", ".lock")
     lock = FileLock(lock_file)
@@ -3500,6 +3531,7 @@ class AlgorithmSelectorCache(PersistentCache):
         prescreening_elapse: Optional[float] = None,
         hint_override: Optional[int] = None,
     ):
+        """Log the autotuning results, currently only handles mm and flex"""
         V.debug.log_autotuning_results(
             name, input_nodes, timings, elapse, precompile_elapse
         )
@@ -3557,6 +3589,26 @@ class AlgorithmSelectorCache(PersistentCache):
                 "num_warps": info["num_warps"],
             }
 
+        def get_flex_attention_choice_info(choice):
+            if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller):
+                return {"type": "extern", "time": timings[choice]}
+
+            assert isinstance(
+                choice, torch._inductor.select_algorithm.TritonTemplateCaller
+            )
+
+            info = choice.info_dict()
+            result = {
+                "type": "triton",
+                "time": timings[choice],
+            }
+
+            for key in _FLEX_ATTENTION_TUNABLE_KEYS:
+                if key in info:
+                    result[key] = info[key]
+
+            return result
+
         mm_filename = get_mm_log_filename()
         if mm_filename and "mm" in name:
             M, K = input_nodes[-2].get_size()[:2]
@@ -3568,6 +3620,44 @@ class AlgorithmSelectorCache(PersistentCache):
 
             append_to_log(mm_filename, out_dict)
 
+        flex_attention_filename = get_flex_attention_log_filename()
+        if flex_attention_filename and "flex_attention" in name:
+            if len(input_nodes) >= 3:
+                query_size = input_nodes[0].get_size()
+                key_size = input_nodes[1].get_size()
+                value_size = input_nodes[2].get_size()
+
+                B = query_size[0]
+                Hq = query_size[1]
+                seq_len_q = query_size[2]
+                qk_head_dim = query_size[3]
+                Hkv = key_size[1]
+                seq_len_kv = key_size[2]
+                v_head_dim = value_size[3]
+
+                kernel_type = "backward" if "backward" in name else "forward"
+                dims_key = str(
+                    (
+                        kernel_type,
+                        B,
+                        Hq,
+                        Hkv,
+                        seq_len_q,
+                        seq_len_kv,
+                        qk_head_dim,
+                        v_head_dim,
+                    )
+                )
+
+                sorted_choices = sorted(timings, key=timings.__getitem__)
+                out_dict = {
+                    dims_key: [
+                        get_flex_attention_choice_info(choice)
+                        for choice in sorted_choices
+                    ]
+                }
+                append_to_log(flex_attention_filename, out_dict)
+
         best_time = timings[best]
         sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
         sys.stderr.write(f"strides: {strides}\n")
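
Usage note (not part of the patch): a minimal sketch of how the new hook is exercised outside the test suite, assuming a CUDA device and an illustrative /tmp path. Set TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE before the first autotune (the filename lookup is cached), run flex_attention under max-autotune, then read back the ".json" log, which per this diff is a list of single-key entries keyed by str((kernel_type, B, Hq, Hkv, seq_len_q, seq_len_kv, qk_head_dim, v_head_dim)) with the benchmarked choices sorted fastest-first.

import json
import os

# Must be set before the first autotune; get_flex_attention_log_filename() is cached.
# "/tmp/flex_configs" is an illustrative path, not part of the diff.
os.environ["TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE"] = "/tmp/flex_configs"

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


# Small shapes matching the new test; requires a CUDA device.
q, k, v = (
    torch.randn(1, 2, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)
block_mask = create_block_mask(causal_mask, 1, 1, 128, 128, device="cuda")

compiled = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")
compiled(q, k, v, block_mask=block_mask).sum().backward()

# One entry per autotuned flex kernel (forward and backward here); choices[0]
# is the fastest benchmarked config since the log is sorted by timing.
with open("/tmp/flex_configs.json") as f:
    for entry in json.load(f):
        dims, choices = next(iter(entry.items()))
        print(dims, choices[0]["type"], choices[0]["time"])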