Clean up some tflop calc and add option for saving (#132799)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132799
Approved by: https://github.com/BoyuanFeng
This commit is contained in:
drisspg
2024-08-06 14:05:30 -07:00
committed by PyTorch MergeBot
parent cbee9c1fd2
commit cb4d1bfb71

View File

@ -1,4 +1,5 @@
import argparse
import csv
import itertools
from collections import defaultdict
from dataclasses import asdict, dataclass
@ -179,7 +180,7 @@ def run_single_experiment(
out_eager.backward, dOut, retain_graph=True
)
out_compile = compiled_sdpa(query, key, value, score_mod)
out_compile = compiled_sdpa(query, key, value, score_mod, block_mask)
dOut = torch.randn_like(out_eager)
backward_compile_time = benchmark_torch_function_in_microseconds(
out_compile.backward, dOut, retain_graph=True
@ -296,7 +297,7 @@ def get_average_speedups(results: List[Experiment], type: str):
return table_data
def print_results(results: List[Experiment]):
def print_results(results: List[Experiment], save_path: Optional[str] = None):
table_data = defaultdict(list)
for experiment in results:
for key, value in experiment.asdict().items():
@ -329,8 +330,8 @@ def print_results(results: List[Experiment]):
table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]]
table_data["mask_mod"] = [get_func_name(func) for func in table_data["mask_mod"]]
print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
print("\n")
print("FWD Speedups".center(125, "="))
print("\n")
@ -344,6 +345,15 @@ def print_results(results: List[Experiment]):
average_data = get_average_speedups(results, type="bwd")
print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
# Optionally persist the table that was printed earlier as a CSV file.
if save_path is not None:
    # newline="" leaves line-ending handling to the csv module.
    with open(save_path, "w", newline="") as csvfile:
        fieldnames = list(table_data)
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        # table_data is column-oriented (column name -> list of values);
        # zip(*columns) transposes it into one tuple per row. With no
        # columns this yields no rows instead of raising StopIteration.
        writer.writerows(
            dict(zip(fieldnames, row_values))
            for row_values in zip(*table_data.values())
        )
    print(f"\nResults saved to {save_path}")
def generate_score_mods(score_mods: List[str]) -> List[Callable | None]:
def noop(score, b, h, m, n):
@ -413,6 +423,77 @@ def get_gqa_mask_mod(mask_mod, G, q_seq_len):
return mask_mod_gqa
def generate_flash_configs(
    calculate_bwd: bool,
    dtype: torch.dtype,
    batch_sizes: List[int],
    num_heads: List[Tuple[int, int]],
    seq_lens: List[int],
    head_dims: List[int],
    score_mods_str: List[str],
    decoding: bool,
    kv_cache_size: List[int],
    cal_bandwidth: bool,
) -> List[ExperimentConfig]:
    """Build a fixed sweep of experiment configs over hard-coded shapes.

    The sweep is the cartesian product of six hard-coded
    (batch_size, seq_len) pairs, two causal flags, head dims 64 and 128,
    and the score/mask mods generated from ``score_mods_str``. The head
    count is derived as ``2048 // head_dim``.

    Args:
        calculate_bwd: Forwarded to ``ExperimentConfig.calculate_bwd_time``.
        dtype: Forwarded to ``ExperimentConfig.dtype``.
        batch_sizes: Accepted but not read; batch sizes are hard-coded.
        num_heads: Accepted but not read; head counts are derived.
        seq_lens: Accepted but not read; sequence lengths are hard-coded.
        head_dims: Accepted but not read; head dims are hard-coded.
        score_mods_str: Names passed to ``generate_score_mods`` and
            ``generate_mask_mods``.
        decoding: When true, the query sequence length is 1 while the
            key/value sequence length stays at ``seq_len``.
        kv_cache_size: Accepted but not read.
        cal_bandwidth: Forwarded to ``ExperimentConfig.cal_bandwidth``.

    Returns:
        One ``ExperimentConfig`` per element of the product, in product order.

    Raises:
        AssertionError: If both ``calculate_bwd`` and ``decoding`` are set.
    """
    assert not (calculate_bwd and decoding), "Decoding does not support backward"

    hidden_dim = 2048
    batch_and_seq_len_pairs = [
        (32, 512),
        (16, 1024),
        (8, 2048),
        (4, 4096),
        (2, 8192),
        (1, 16384),
    ]
    causal_options = [False, True]
    head_dim_options = [64, 128]

    score_mods = generate_score_mods(score_mods_str)
    mask_mods = generate_mask_mods(score_mods_str)

    sweep = itertools.product(
        batch_and_seq_len_pairs,
        causal_options,
        head_dim_options,
        score_mods,
        mask_mods,
    )

    configs = []
    # NOTE(review): the causal flag is iterated but never used, so every
    # config appears twice. Kept as-is to preserve the sweep size — confirm
    # whether it is meant to select a mask.
    for (batch_size, seq_len), _causal, head_dim, score_mod, mask_mod in sweep:
        heads = hidden_dim // head_dim
        q_seq_len = 1 if decoding else seq_len
        kv_seq_len = seq_len
        # The same head count fills both head slots of the shape tuple.
        shape = (batch_size, heads, q_seq_len, heads, kv_seq_len, head_dim)
        configs.append(
            ExperimentConfig(
                shape=shape,
                score_mod=score_mod,
                mask_mod=mask_mod,
                dtype=dtype,
                calculate_bwd_time=calculate_bwd,
                cal_bandwidth=cal_bandwidth,
            )
        )
    return configs
def generate_experiment_configs(
calculate_bwd: bool,
dtype: torch.dtype,
@ -510,7 +591,7 @@ def main(args):
)
)
print_results(results)
print_results(results, args.save_path)
def heads_input_type(s):
@ -581,7 +662,12 @@ Ignores -b batch size and calculate batch size from kv_cache size instead when s
action="store_true",
help="Calculate kernel memory bandwidth & computational throughput. ",
)
# Optional output location; print_results writes the results table there as CSV.
parser.add_argument(
    "--save-path",
    type=str,
    help="Path to save the results CSV file (optional)",
    default=None,
)
# Parse arguments
args = parser.parse_args()
args.dtype = getattr(torch, args.dtype)