Mirror of https://github.com/pytorch/pytorch.git
Clean up some tflop calc and add option for saving (#132799)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132799
Approved by: https://github.com/BoyuanFeng

Commit cb4d1bfb71 (parent cbee9c1fd2), committed by PyTorch MergeBot.
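The tflop cleanup named in the title is not visible in the hunks below. For orientation, a minimal sketch of the conventional attention FLOP count that benchmarks of this kind use (flash-attention convention); the helper names are illustrative, not the PR's code:

import torch

def attention_flops(B: int, H: int, M: int, N: int, D: int, causal: bool = False) -> float:
    # QK^T and P@V each cost 2*B*H*M*N*D FLOPs; a causal mask halves the work.
    flops = 4 * B * H * M * N * D
    return flops * 0.5 if causal else flops

def tflops_per_second(flops: float, time_us: float) -> float:
    # Time is in microseconds, matching benchmark_torch_function_in_microseconds.
    return flops / (time_us * 1e-6) / 1e12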
@@ -1,4 +1,5 @@
 import argparse
+import csv
 import itertools
 from collections import defaultdict
 from dataclasses import asdict, dataclass
@@ -179,7 +180,7 @@ def run_single_experiment(
         out_eager.backward, dOut, retain_graph=True
     )
 
-    out_compile = compiled_sdpa(query, key, value, score_mod)
+    out_compile = compiled_sdpa(query, key, value, score_mod, block_mask)
     dOut = torch.randn_like(out_eager)
     backward_compile_time = benchmark_torch_function_in_microseconds(
         out_compile.backward, dOut, retain_graph=True
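The compiled call now receives the precomputed block_mask explicitly, so the eager and compiled paths run against the same mask. For context, a minimal sketch of the public FlexAttention API this benchmark exercises (shapes and the causal mask_mod here are illustrative):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 2, 8, 1024, 64
q, k, v = (
    torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3)
)

# Precompute block-sparsity info so fully masked tiles are skipped entirely.
block_mask = create_block_mask(causal_mask, B=B, H=H, Q_LEN=S, KV_LEN=S)
compiled_sdpa = torch.compile(flex_attention)
out = compiled_sdpa(q, k, v, block_mask=block_mask)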
@@ -296,7 +297,7 @@ def get_average_speedups(results: List[Experiment], type: str):
     return table_data
 
 
-def print_results(results: List[Experiment]):
+def print_results(results: List[Experiment], save_path: Optional[str] = None):
     table_data = defaultdict(list)
     for experiment in results:
         for key, value in experiment.asdict().items():
@@ -329,8 +330,8 @@ def print_results(results: List[Experiment]):
 
     table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]]
     table_data["mask_mod"] = [get_func_name(func) for func in table_data["mask_mod"]]
-    print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
-
+
+    print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
     print("\n")
     print("FWD Speedups".center(125, "="))
     print("\n")
@@ -344,6 +345,15 @@ def print_results(results: List[Experiment]):
     average_data = get_average_speedups(results, type="bwd")
     print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
 
+    if save_path is not None:
+        with open(save_path, "w", newline="") as csvfile:
+            writer = csv.DictWriter(csvfile, fieldnames=table_data.keys())
+            writer.writeheader()
+            for i in range(len(next(iter(table_data.values())))):
+                row = {k: v[i] for k, v in table_data.items()}
+                writer.writerow(row)
+        print(f"\nResults saved to {save_path}")
+
 
 def generate_score_mods(score_mods: List[str]) -> List[Callable | None]:
     def noop(score, b, h, m, n):
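The new save path serializes the column-oriented table_data (a dict mapping column names to equal-length lists) row by row. The same pattern as a standalone sketch, with made-up column data:

import csv

# Column-oriented data, as print_results builds it: one list per column.
table_data = {
    "shape": ["(2, 16, 512, 64)", "(2, 16, 1024, 64)"],
    "fwd_speedup": [1.42, 1.37],
}

with open("results.csv", "w", newline="") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=table_data.keys())
    writer.writeheader()
    # Transpose columns into rows; all columns are assumed the same length.
    for i in range(len(next(iter(table_data.values())))):
        writer.writerow({k: v[i] for k, v in table_data.items()})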
@@ -413,6 +423,77 @@ def get_gqa_mask_mod(mask_mod, G, q_seq_len):
     return mask_mod_gqa
 
 
+def generate_flash_configs(
+    calculate_bwd: bool,
+    dtype: torch.dtype,
+    batch_sizes: List[int],
+    num_heads: List[Tuple[int, int]],
+    seq_lens: List[int],
+    head_dims: List[int],
+    score_mods_str: List[str],
+    decoding: bool,
+    kv_cache_size: List[int],
+    cal_bandwidth: bool,
+) -> List[ExperimentConfig]:
+    assert not (calculate_bwd and decoding), "Decoding does not support backward"
+
+    bs_seqlen_vals = [
+        (32, 512),
+        (16, 1024),
+        (8, 2048),
+        (4, 4096),
+        (2, 8192),
+        (1, 16384),
+    ]
+    causal_vals = [False, True]
+    headdim_vals = [64, 128]
+    dim = 2048
+
+    score_mods = generate_score_mods(score_mods_str)
+    mask_mods = generate_mask_mods(score_mods_str)
+    all_configs = []
+
+    for (
+        (batch_size, seq_len),
+        causal,
+        head_dim,
+        score_mod,
+        mask_mod,
+    ) in itertools.product(
+        bs_seqlen_vals,
+        causal_vals,
+        headdim_vals,
+        score_mods,
+        mask_mods,
+    ):
+        num_heads = dim // head_dim
+
+        if decoding:
+            q_seq_len, kv_seq_len = 1, seq_len
+        else:
+            q_seq_len = kv_seq_len = seq_len
+
+        all_configs.append(
+            ExperimentConfig(
+                shape=(
+                    batch_size,
+                    num_heads,
+                    q_seq_len,
+                    num_heads,
+                    kv_seq_len,
+                    head_dim,
+                ),
+                score_mod=score_mod,
+                mask_mod=mask_mod,
+                dtype=dtype,
+                calculate_bwd_time=calculate_bwd,
+                cal_bandwidth=cal_bandwidth,
+            )
+        )
+
+    return all_configs
+
+
 def generate_experiment_configs(
     calculate_bwd: bool,
     dtype: torch.dtype,
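generate_flash_configs reproduces a flash-attention-style sweep: the batch/sequence pairs are chosen so every config processes the same 16384 tokens, with head dims 64 and 128 against a fixed model dim of 2048. The six-element shape carries separate slots for query and key/value head counts (both set to num_heads here), which is what lets other paths express GQA. A hypothetical consumer would unpack it roughly as below; the helper name and variable names are illustrative:

import torch

def make_qkv(config):  # config: an ExperimentConfig as constructed above
    # batch, q heads, q len, kv heads, kv len, head dim
    B, Hq, M, Hkv, N, D = config.shape
    q = torch.randn(B, Hq, M, D, dtype=config.dtype, device="cuda")
    k = torch.randn(B, Hkv, N, D, dtype=config.dtype, device="cuda")
    v = torch.randn(B, Hkv, N, D, dtype=config.dtype, device="cuda")
    return q, k, v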
@@ -510,7 +591,7 @@ def main(args):
         )
     )
 
-    print_results(results)
+    print_results(results, args.save_path)
 
 
 def heads_input_type(s):
@@ -581,7 +662,12 @@ Ignores -b batch size and calculate batch size from kv_cache size instead when s
         action="store_true",
         help="Calculate kernel memory bandwidth & computational throughput. ",
     )
+    parser.add_argument(
+        "--save-path",
+        type=str,
+        help="Path to save the results JSON file (optional)",
+        default=None,
+    )
     # Parse arguments
     args = parser.parse_args()
     args.dtype = getattr(torch, args.dtype)
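With the flag wired through main, results can be persisted in one pass. A hypothetical invocation (the script path is assumed, and note the help string says JSON while print_results above writes CSV):

python benchmarks/transformer/score_mod.py --save-path results.csv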