Add Full block support to flex_decoding (#131404)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131404
Approved by: https://github.com/yanboliang
Author: joydddd
Date: 2024-07-31 20:04:45 -07:00
Committed by: PyTorch MergeBot
parent 043e41f4f4
commit bdd83c4c7f
5 changed files with 261 additions and 176 deletions

@@ -37,6 +37,7 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->
 class ExperimentConfig:
     shape: Tuple[int]
     score_mod: Callable
+    mask_mod: Callable
     dtype: torch.dtype
     calculate_bwd_time: bool
     cal_bandwidth: bool
@@ -122,7 +123,9 @@ def generate_inputs(
 def run_single_experiment(
-    config: ExperimentConfig, dynamic=False, max_autotune=False, enable_mask=False
+    config: ExperimentConfig,
+    dynamic=False,
+    max_autotune=False,
 ) -> ExperimentResults:
     device = torch.device("cuda")
     batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
@@ -149,13 +152,14 @@ def run_single_experiment(
     compiled_sdpa = torch.compile(flex_attention, dynamic=dynamic)
 
     score_mod = config.score_mod
+    mask_mod = config.mask_mod
 
-    if enable_mask:
+    if mask_mod:
         block_mask = create_block_mask(
-            score_mod, 1, 1, q_seq_len * (q_heads // kv_heads), kv_seq_len, query.device
+            mask_mod, 1, 1, q_seq_len * (q_heads // kv_heads), kv_seq_len, query.device
         )
     else:
-        block_mask = _create_empty_block_mask(query, key, value)
+        block_mask = _create_empty_block_mask(query, key)
 
     forward_eager_time = benchmark_torch_function_in_microseconds(
         eager_sdpa, query, key, value, score_mod
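For context, a minimal sketch of the mask_mod path this hunk switches to. It assumes a PyTorch build where flex_attention and create_block_mask are exported from torch.nn.attention.flex_attention (2.5+/nightly); the shapes below are illustrative, not taken from the benchmark:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, m, n):
    # A mask_mod is a pure visibility predicate over (batch, head, q_idx, kv_idx).
    return m >= n

B, H, S, D = 4, 8, 1024, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))

# create_block_mask classifies (q-block, kv-block) tiles as empty, partial, or
# full; fully-unmasked ("full") tiles can skip mask evaluation entirely, which
# is the full-block support this PR adds to the decoding kernel.
block_mask = create_block_mask(causal, 1, 1, S, S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)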
@ -328,7 +332,7 @@ def print_results(results: List[Experiment]):
print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
def generate_score_mods(score_mods: List[str]) -> List[Callable]:
def generate_score_mods(score_mods: List[str]) -> List[Callable | None]:
def noop(score, b, h, m, n):
return score
@@ -343,14 +347,33 @@ def generate_score_mods(score_mods: List[str]) -> List[Callable]:
     function_dict = {
         "noop": noop,
-        "causal": causal_mask,
+        "causal": None,
         "rel": relative_bias,
         "head_bias": head_bias,
     }
     return [function_dict[name] for name in score_mods]
 
 
+def generate_mask_mods(score_mods: List[str]) -> List[Callable | None]:
+    def noop(b, h, m, n):
+        return True
+
+    def causal(b, h, m, n):
+        return m >= n
+
+    mask_mod_dict = {
+        "noop": None,
+        "causal": causal,
+        "rel": None,
+        "head_bias": None,
+    }
+    return [mask_mod_dict[name] for name in score_mods]
+
+
 def get_gqa_score_mod(score_mod, G, q_seq_len):
+    if score_mod is None:
+        return None
+
     def score_mod_gqa(score, b, hkv, m, n):
         g = m // q_seq_len
         new_m = m % q_seq_len
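Why "causal" moves out of function_dict and into mask_mod_dict: expressed as a score_mod, causality has to compute every score and then overwrite it, whereas a mask_mod is a boolean rule that create_block_mask can compile into skipped (empty) and unmasked (full) tiles. A sketch of the two styles; the score_mod variant below is the common flex_attention idiom, reconstructed rather than copied from this file:

import torch

def causal_score_mod(score, b, h, m, n):
    # Old style: the full attention math runs for every position, and masked
    # entries are overwritten with -inf after the fact.
    return torch.where(m >= n, score, -float("inf"))

def causal_mask_mod(b, h, m, n):
    # New style: a visibility predicate the kernel can exploit at block
    # granularity instead of per element.
    return m >= n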
@@ -362,6 +385,21 @@ def get_gqa_score_mod(score_mod, G, q_seq_len):
     return score_mod_gqa
 
 
+def get_gqa_mask_mod(mask_mod, G, q_seq_len):
+    if mask_mod is None:
+        return None
+
+    def mask_mod_gqa(b, h, m, n):
+        g = m // q_seq_len
+        new_m = m % q_seq_len
+        hq = h * G + g
+        return mask_mod(b, hq, new_m, n)
+
+    mask_mod_name = get_func_name(mask_mod)
+    set_func_name(mask_mod_gqa, mask_mod_name + "_gqa")
+    return mask_mod_gqa
+
+
 def generate_experiment_configs(
     calculate_bwd: bool,
     dtype: torch.dtype,
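A worked example of the index arithmetic in mask_mod_gqa above: the benchmark packs the G query heads that share one KV head along the query dimension, so a packed row index m encodes both the group member and the original query position. The numbers here are illustrative:

G, q_seq_len = 4, 128  # 4 query heads per KV head, 128 queries per head

def unpack(hkv, m):
    g = m // q_seq_len     # which query head within the KV group
    new_m = m % q_seq_len  # original query position
    hq = hkv * G + g       # original query-head index
    return hq, new_m

assert unpack(0, 0) == (0, 0)     # head 0 of group 0, first row
assert unpack(0, 130) == (1, 2)   # row 130 = head 1 of group 0, position 2
assert unpack(2, 385) == (11, 1)  # 2 * 4 + 385 // 128 = 11; 385 % 128 = 1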
@@ -369,7 +407,7 @@ def generate_experiment_configs(
     num_heads: List[Tuple[int, int]],
     seq_lens: List[int],
     head_dims: List[int],
-    score_mods: List[str],
+    score_mods_str: List[str],
     decoding: bool,
     kv_cache_size: List[int],
     cal_bandwidth: bool,
@@ -381,7 +419,8 @@ def generate_experiment_configs(
     else:
         q_kv_seq_lens = [(i, i) for i in seq_lens]  # only testing q_len == kv_len
     dtypes = [dtype]
-    score_mods = generate_score_mods(score_mods)
+    score_mods = generate_score_mods(score_mods_str)
+    mask_mods = generate_mask_mods(score_mods_str)
     all_configs = []
     for (
         bsz,
@@ -389,6 +428,7 @@ def generate_experiment_configs(
         (q_seq_len, kv_seq_len),
         head_dim,
         score_mod,
+        mask_mod,
         dtype,
     ) in itertools.product(
         kv_cache_size if kv_cache_size else batch_sizes,
@@ -396,6 +436,7 @@ def generate_experiment_configs(
         q_kv_seq_lens,
         head_dims,
         score_mods,
+        mask_mods,
         dtypes,
     ):
         if kv_cache_size:
@@ -410,11 +451,13 @@ def generate_experiment_configs(
             assert q_heads % kv_heads == 0
             G = q_heads // kv_heads
             score_mod = get_gqa_score_mod(score_mod, G, q_seq_len)
+            mask_mod = get_gqa_mask_mod(mask_mod, G, q_seq_len)
 
         all_configs.append(
             ExperimentConfig(
                 shape=(bsz, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim),
                 score_mod=score_mod,
+                mask_mod=mask_mod,
                 dtype=dtype,
                 calculate_bwd_time=calculate_bwd,
                 cal_bandwidth=cal_bandwidth,
@@ -450,7 +493,6 @@ def main(args):
                     config,
                     dynamic=args.dynamic,
                     max_autotune=args.max_autotune,
-                    enable_mask=args.mask,
                 ),
             )
         )
@@ -526,9 +568,6 @@ Ignores -b batch size and calculate batch size from kv_cache size instead when s
         action="store_true",
         help="Calculate kernel memory bandwidth & computational throughput. ",
     )
-    parser.add_argument(
-        "--mask", action="store_true", help="Enables block sparsity mask. "
-    )
 
     # Parse arguments
     args = parser.parse_args()