mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Clean up some tflop calc and add option for saving (#132799)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132799 Approved by: https://github.com/BoyuanFeng
This commit is contained in:
committed by
PyTorch MergeBot
parent
cbee9c1fd2
commit
cb4d1bfb71
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import csv
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from dataclasses import asdict, dataclass
|
||||
@ -179,7 +180,7 @@ def run_single_experiment(
|
||||
out_eager.backward, dOut, retain_graph=True
|
||||
)
|
||||
|
||||
out_compile = compiled_sdpa(query, key, value, score_mod)
|
||||
out_compile = compiled_sdpa(query, key, value, score_mod, block_mask)
|
||||
dOut = torch.randn_like(out_eager)
|
||||
backward_compile_time = benchmark_torch_function_in_microseconds(
|
||||
out_compile.backward, dOut, retain_graph=True
|
||||
@ -296,7 +297,7 @@ def get_average_speedups(results: List[Experiment], type: str):
|
||||
return table_data
|
||||
|
||||
|
||||
def print_results(results: List[Experiment]):
|
||||
def print_results(results: List[Experiment], save_path: Optional[str] = None):
|
||||
table_data = defaultdict(list)
|
||||
for experiment in results:
|
||||
for key, value in experiment.asdict().items():
|
||||
@ -329,8 +330,8 @@ def print_results(results: List[Experiment]):
|
||||
|
||||
table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]]
|
||||
table_data["mask_mod"] = [get_func_name(func) for func in table_data["mask_mod"]]
|
||||
print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
|
||||
|
||||
print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
|
||||
print("\n")
|
||||
print("FWD Speedups".center(125, "="))
|
||||
print("\n")
|
||||
@ -344,6 +345,15 @@ def print_results(results: List[Experiment]):
|
||||
average_data = get_average_speedups(results, type="bwd")
|
||||
print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
|
||||
|
||||
if save_path is not None:
|
||||
with open(save_path, "w", newline="") as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=table_data.keys())
|
||||
writer.writeheader()
|
||||
for i in range(len(next(iter(table_data.values())))):
|
||||
row = {k: v[i] for k, v in table_data.items()}
|
||||
writer.writerow(row)
|
||||
print(f"\nResults saved to {save_path}")
|
||||
|
||||
|
||||
def generate_score_mods(score_mods: List[str]) -> List[Callable | None]:
|
||||
def noop(score, b, h, m, n):
|
||||
@ -413,6 +423,77 @@ def get_gqa_mask_mod(mask_mod, G, q_seq_len):
|
||||
return mask_mod_gqa
|
||||
|
||||
|
||||
def generate_flash_configs(
|
||||
calculate_bwd: bool,
|
||||
dtype: torch.dtype,
|
||||
batch_sizes: List[int],
|
||||
num_heads: List[Tuple[int, int]],
|
||||
seq_lens: List[int],
|
||||
head_dims: List[int],
|
||||
score_mods_str: List[str],
|
||||
decoding: bool,
|
||||
kv_cache_size: List[int],
|
||||
cal_bandwidth: bool,
|
||||
) -> List[ExperimentConfig]:
|
||||
assert not (calculate_bwd and decoding), "Decoding does not support backward"
|
||||
|
||||
bs_seqlen_vals = [
|
||||
(32, 512),
|
||||
(16, 1024),
|
||||
(8, 2048),
|
||||
(4, 4096),
|
||||
(2, 8192),
|
||||
(1, 16384),
|
||||
]
|
||||
causal_vals = [False, True]
|
||||
headdim_vals = [64, 128]
|
||||
dim = 2048
|
||||
|
||||
score_mods = generate_score_mods(score_mods_str)
|
||||
mask_mods = generate_mask_mods(score_mods_str)
|
||||
all_configs = []
|
||||
|
||||
for (
|
||||
(batch_size, seq_len),
|
||||
causal,
|
||||
head_dim,
|
||||
score_mod,
|
||||
mask_mod,
|
||||
) in itertools.product(
|
||||
bs_seqlen_vals,
|
||||
causal_vals,
|
||||
headdim_vals,
|
||||
score_mods,
|
||||
mask_mods,
|
||||
):
|
||||
num_heads = dim // head_dim
|
||||
|
||||
if decoding:
|
||||
q_seq_len, kv_seq_len = 1, seq_len
|
||||
else:
|
||||
q_seq_len = kv_seq_len = seq_len
|
||||
|
||||
all_configs.append(
|
||||
ExperimentConfig(
|
||||
shape=(
|
||||
batch_size,
|
||||
num_heads,
|
||||
q_seq_len,
|
||||
num_heads,
|
||||
kv_seq_len,
|
||||
head_dim,
|
||||
),
|
||||
score_mod=score_mod,
|
||||
mask_mod=mask_mod,
|
||||
dtype=dtype,
|
||||
calculate_bwd_time=calculate_bwd,
|
||||
cal_bandwidth=cal_bandwidth,
|
||||
)
|
||||
)
|
||||
|
||||
return all_configs
|
||||
|
||||
|
||||
def generate_experiment_configs(
|
||||
calculate_bwd: bool,
|
||||
dtype: torch.dtype,
|
||||
@ -510,7 +591,7 @@ def main(args):
|
||||
)
|
||||
)
|
||||
|
||||
print_results(results)
|
||||
print_results(results, args.save_path)
|
||||
|
||||
|
||||
def heads_input_type(s):
|
||||
@ -581,7 +662,12 @@ Ignores -b batch size and calculate batch size from kv_cache size instead when s
|
||||
action="store_true",
|
||||
help="Calculate kernel memory bandwidth & computational throughput. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
help="Path to save the results JSON file (optional)",
|
||||
default=None,
|
||||
)
|
||||
# Parse arguments
|
||||
args = parser.parse_args()
|
||||
args.dtype = getattr(torch, args.dtype)
|
||||
|
Reference in New Issue
Block a user