Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add Full block support to flex_decoding (#131404)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131404
Approved by: https://github.com/yanboliang
Committed by: PyTorch MergeBot
Parent: 043e41f4f4
Commit: bdd83c4c7f
@@ -37,6 +37,7 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->
 class ExperimentConfig:
     shape: Tuple[int]
     score_mod: Callable
+    mask_mod: Callable
     dtype: torch.dtype
     calculate_bwd_time: bool
     cal_bandwidth: bool
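For orientation, here is a minimal sketch of a config carrying the new field. The dataclass is restated locally so the snippet runs on its own; the causal predicate and the concrete shape values are illustrative, not taken from the PR.

from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import torch

@dataclass
class ExperimentConfig:  # mirrors the fields of the benchmark's dataclass
    shape: Tuple[int, ...]
    score_mod: Optional[Callable]
    mask_mod: Optional[Callable]
    dtype: torch.dtype
    calculate_bwd_time: bool
    cal_bandwidth: bool

def causal(b, h, m, n):
    # mask_mod convention: return True where attention is allowed
    return m >= n

cfg = ExperimentConfig(
    shape=(8, 16, 1, 2, 4096, 64),  # (bsz, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim)
    score_mod=None,   # "causal" no longer ships a score_mod...
    mask_mod=causal,  # ...its sparsity is expressed as a mask_mod instead
    dtype=torch.bfloat16,
    calculate_bwd_time=False,
    cal_bandwidth=True,
)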
@@ -122,7 +123,9 @@ def generate_inputs(
 
 
 def run_single_experiment(
-    config: ExperimentConfig, dynamic=False, max_autotune=False, enable_mask=False
+    config: ExperimentConfig,
+    dynamic=False,
+    max_autotune=False,
 ) -> ExperimentResults:
     device = torch.device("cuda")
     batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
@@ -149,13 +152,14 @@ def run_single_experiment(
     compiled_sdpa = torch.compile(flex_attention, dynamic=dynamic)
 
     score_mod = config.score_mod
+    mask_mod = config.mask_mod
 
-    if enable_mask:
+    if mask_mod:
         block_mask = create_block_mask(
-            score_mod, 1, 1, q_seq_len * (q_heads // kv_heads), kv_seq_len, query.device
+            mask_mod, 1, 1, q_seq_len * (q_heads // kv_heads), kv_seq_len, query.device
         )
     else:
-        block_mask = _create_empty_block_mask(query, key, value)
+        block_mask = _create_empty_block_mask(query, key)
 
     forward_eager_time = benchmark_torch_function_in_microseconds(
         eager_sdpa, query, key, value, score_mod
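A self-contained sketch of the mask_mod-driven pattern this hunk switches to. The module path shown is the one in recent PyTorch releases; it differed in early nightlies, so treat it as version-dependent.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, m, n):
    return m >= n

B, H, Q_LEN, KV_LEN, D = 1, 8, 1024, 1024, 64
q = torch.randn(B, H, Q_LEN, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, KV_LEN, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, KV_LEN, D, device="cuda", dtype=torch.float16)

# Building the BlockMask from a boolean mask_mod classifies each tile as
# empty, partial, or full up front, so the kernel can skip empty tiles and
# take the no-masking fast path on full tiles.
block_mask = create_block_mask(causal, 1, 1, Q_LEN, KV_LEN, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)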
@@ -328,7 +332,7 @@ def print_results(results: List[Experiment]):
     print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
 
 
-def generate_score_mods(score_mods: List[str]) -> List[Callable]:
+def generate_score_mods(score_mods: List[str]) -> List[Callable | None]:
     def noop(score, b, h, m, n):
         return score
 
@@ -343,14 +347,33 @@ def generate_score_mods(score_mods: List[str]) -> List[Callable]:
 
     function_dict = {
         "noop": noop,
-        "causal": causal_mask,
+        "causal": None,
         "rel": relative_bias,
         "head_bias": head_bias,
     }
     return [function_dict[name] for name in score_mods]
 
 
+def generate_mask_mods(score_mods: List[str]) -> List[Callable | None]:
+    def noop(b, h, m, n):
+        return True
+
+    def causal(b, h, m, n):
+        return m >= n
+
+    mask_mod_dict = {
+        "noop": None,
+        "causal": causal,
+        "rel": None,
+        "head_bias": None,
+    }
+    return [mask_mod_dict[name] for name in score_mods]
+
+
 def get_gqa_score_mod(score_mod, G, q_seq_len):
+    if score_mod is None:
+        return None
+
     def score_mod_gqa(score, b, hkv, m, n):
         g = m // q_seq_len
         new_m = m % q_seq_len
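The table split above is the crux of the change: "causal" moves from a score_mod (mask folded into the score values) to a mask_mod (a boolean predicate the block mask is built from), which is what enables full-block skipping. A sketch of the two equivalent formulations, with illustrative names:

import torch

# score_mod style: masking happens inside the score computation, so every
# block still has to be evaluated.
def causal_score_mod(score, b, h, m, n):
    return torch.where(m >= n, score, torch.full_like(score, float("-inf")))

# mask_mod style: a pure predicate; blocks that are entirely False can be
# skipped, and blocks that are entirely True can omit masking altogether.
def causal_mask_mod(b, h, m, n):
    return m >= n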
@@ -362,6 +385,21 @@ def get_gqa_score_mod(score_mod, G, q_seq_len):
     return score_mod_gqa
 
 
+def get_gqa_mask_mod(mask_mod, G, q_seq_len):
+    if mask_mod is None:
+        return None
+
+    def mask_mod_gqa(b, h, m, n):
+        g = m // q_seq_len
+        new_m = m % q_seq_len
+        hq = h * G + g
+        return mask_mod(b, hq, new_m, n)
+
+    mask_mod_name = get_func_name(mask_mod)
+    set_func_name(mask_mod_gqa, mask_mod_name + "_gqa")
+    return mask_mod_gqa
+
+
 def generate_experiment_configs(
     calculate_bwd: bool,
     dtype: torch.dtype,
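For decoding, the benchmark folds the G query heads that share one KV head into the query-sequence axis, so a flattened query index m encodes both the group member and the true position. A small check of that index arithmetic, with illustrative values:

G, q_seq_len = 4, 2  # 4 query heads per KV head, 2 true query positions

def causal(b, h, m, n):
    return m >= n

def causal_gqa(b, hkv, m, n):
    # same arithmetic as mask_mod_gqa above
    g = m // q_seq_len      # which head within the group
    new_m = m % q_seq_len   # true query position
    hq = hkv * G + g        # unfolded query-head index
    return causal(b, hq, new_m, n)

# Flattened index m = 5 decodes to group member g = 2 at position new_m = 1,
# i.e. query head hq = hkv * G + 2.
assert causal_gqa(0, 0, 5, 1) == causal(0, 2, 1, 1)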
@@ -369,7 +407,7 @@ def generate_experiment_configs(
     num_heads: List[Tuple[int, int]],
     seq_lens: List[int],
     head_dims: List[int],
-    score_mods: List[str],
+    score_mods_str: List[str],
     decoding: bool,
     kv_cache_size: List[int],
     cal_bandwidth: bool,
@@ -381,7 +419,8 @@ def generate_experiment_configs(
     else:
         q_kv_seq_lens = [(i, i) for i in seq_lens]  # only testing q_len == kv_len
     dtypes = [dtype]
-    score_mods = generate_score_mods(score_mods)
+    score_mods = generate_score_mods(score_mods_str)
+    mask_mods = generate_mask_mods(score_mods_str)
     all_configs = []
     for (
         bsz,
@@ -389,6 +428,7 @@ def generate_experiment_configs(
         (q_seq_len, kv_seq_len),
         head_dim,
         score_mod,
+        mask_mod,
         dtype,
     ) in itertools.product(
         kv_cache_size if kv_cache_size else batch_sizes,
@@ -396,6 +436,7 @@ def generate_experiment_configs(
         q_kv_seq_lens,
         head_dims,
         score_mods,
+        mask_mods,
         dtypes,
     ):
         if kv_cache_size:
@@ -410,11 +451,13 @@ def generate_experiment_configs(
             assert q_heads % kv_heads == 0
             G = q_heads // kv_heads
             score_mod = get_gqa_score_mod(score_mod, G, q_seq_len)
+            mask_mod = get_gqa_mask_mod(mask_mod, G, q_seq_len)
 
         all_configs.append(
             ExperimentConfig(
                 shape=(bsz, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim),
                 score_mod=score_mod,
+                mask_mod=mask_mod,
                 dtype=dtype,
                 calculate_bwd_time=calculate_bwd,
                 cal_bandwidth=cal_bandwidth,
@@ -450,7 +493,6 @@ def main(args):
                    config,
                    dynamic=args.dynamic,
                    max_autotune=args.max_autotune,
-                    enable_mask=args.mask,
                ),
            )
        )
@@ -526,9 +568,6 @@ Ignores -b batch size and calculate batch size from kv_cache size instead when s
         action="store_true",
         help="Calculate kernel memory bandwidth & computational throughput. ",
     )
-    parser.add_argument(
-        "--mask", action="store_true", help="Enables block sparsity mask. "
-    )
 
     # Parse arguments
     args = parser.parse_args()