Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Add Lowering for FlexAttention Backwards (#125515)"
This reverts commit 95b9e981c3ab68fc17f78b8a6bbfd9569745ae4c.
Reverted https://github.com/pytorch/pytorch/pull/125515 on behalf of https://github.com/huydhn due to: Sorry for reverting your change, but the newly added test runs out of memory 95b9e981c3 ([comment](https://github.com/pytorch/pytorch/pull/125515#issuecomment-2114084869))
@@ -3,7 +3,7 @@ import itertools
 from collections import defaultdict
 from dataclasses import asdict, dataclass
 from functools import partial
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List
 
 import numpy as np
 import torch
@@ -29,32 +29,28 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->
 
 @dataclass(frozen=True)
 class ExperimentConfig:
-    shape: Tuple[int]
+    batch_size: int
+    num_heads: int
+    q_seq_len: int
+    k_seq_len: int
+    head_dim: int
     score_mod: Callable
     dtype: torch.dtype
-    calculate_bwd_time: bool
-
-    def __post_init__(self):
-        assert len(self.shape) == 4, "Shape must be of length 4"
 
     def asdict(self):
-        # Convert the dataclass instance to a dictionary
-        d = asdict(self)
-        # Remove the 'calculate_bwd_time' key
-        d.pop("calculate_bwd_time", None)
-        return d
-
-
-@dataclass(frozen=True)
-class Times:
-    eager_time: float
-    compiled_time: float
+        return asdict(self)
 
 
 @dataclass(frozen=True)
 class ExperimentResults:
-    fwd_times: Times
-    bwd_times: Optional[Times]
+    eager_time: float
+    compiled_time: float
+
+    def get_entries(self) -> List:
+        return [
+            f"{self.eager_time:2f}",
+            f"{self.compiled_time:2f}",
+        ]
 
 
 @dataclass(frozen=True)
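The restored `ExperimentResults` keeps plain float fields and relies on `dataclasses.asdict` for serialization, rather than the nested `Times` records being removed. A minimal, self-contained sketch of that round trip (values are hypothetical):

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class ExperimentResults:
    eager_time: float  # microseconds for eager SDPA
    compiled_time: float  # microseconds for compiled flex attention


results = ExperimentResults(eager_time=123.4, compiled_time=61.7)
print(asdict(results))  # {'eager_time': 123.4, 'compiled_time': 61.7}
```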
@@ -62,31 +58,29 @@ class Experiment:
     config: ExperimentConfig
     results: ExperimentResults
 
+    def get_entries(self) -> List:
+        return self.config.get_entries() + self.results.get_entries()
+
     def asdict(self):
-        dict1 = self.config.asdict()
+        dict1 = asdict(self.config)
         dict2 = asdict(self.results)
         return {**dict1, **dict2}
 
 
 def generate_inputs(
-    batch_size: int,
-    num_heads: int,
-    q_sequence_length: int,
-    kv_sequence_length: int,
-    head_dim: int,
-    dtype: torch.dtype,
-    device: torch.device,
-    requires_grad: bool,
+    batch_size,
+    num_heads,
+    q_sequence_length,
+    kv_sequence_length,
+    head_dim,
+    dtype,
+    device,
 ):
     q_shape = (batch_size, q_sequence_length, num_heads * head_dim)
     kv_shape = (batch_size, kv_sequence_length, num_heads * head_dim)
 
-    make_q = partial(
-        torch.rand, q_shape, device=device, dtype=dtype, requires_grad=requires_grad
-    )
-    make_kv = partial(
-        torch.rand, kv_shape, device=device, dtype=dtype, requires_grad=requires_grad
-    )
+    make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
+    make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
     query = (
         make_q()
         .view(batch_size, q_sequence_length, num_heads, head_dim)
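`generate_inputs` builds `(batch, seq, heads * head_dim)` tensors and reshapes them per head. A runnable sketch of the same pattern; the `.transpose(1, 2)` step after the truncated `.view` call is an assumption here, and the sizes are hypothetical:

```python
from functools import partial

import torch

batch_size, num_heads, seq_len, head_dim = 8, 16, 512, 64
# Same construction as make_q above, on CPU for illustration.
make_q = partial(
    torch.rand,
    (batch_size, seq_len, num_heads * head_dim),
    device="cpu",
    dtype=torch.float32,
)
query = (
    make_q()
    .view(batch_size, seq_len, num_heads, head_dim)
    .transpose(1, 2)  # assumed follow-up: -> (batch, heads, seq, head_dim)
)
print(query.shape)  # torch.Size([8, 16, 512, 64])
```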
@@ -107,16 +101,14 @@ def generate_inputs(
 
 def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults:
     device = torch.device("cuda")
-    batch_size, num_heads, q_seq_len, head_dim = config.shape
     query, key, value = generate_inputs(
-        batch_size,
-        num_heads,
-        q_seq_len,
-        q_seq_len,
-        head_dim,
+        config.batch_size,
+        config.num_heads,
+        config.q_seq_len,
+        config.k_seq_len,
+        config.head_dim,
         config.dtype,
         device,
-        requires_grad=config.calculate_bwd_time,
     )
 
     def eager_sdpa(query, key, value, _):
@@ -133,47 +125,23 @@ def run_single_experiment(config: ExperimentConfig, dynamic=False) -> Experiment
         compiled_sdpa, query, key, value, score_mod
     )
 
-    if config.calculate_bwd_time:
-        out_eager = eager_sdpa(query, key, value, score_mod)
-        dOut = torch.randn_like(out_eager)
-        backward_eager_time = benchmark_torch_function_in_microseconds(
-            out_eager.backward, dOut, retain_graph=True
-        )
-
-        out_compile = compiled_sdpa(query, key, value, score_mod)
-        dOut = torch.randn_like(out_eager)
-        backward_compile_time = benchmark_torch_function_in_microseconds(
-            out_compile.backward, dOut, retain_graph=True
-        )
-
-        return ExperimentResults(
-            fwd_times=Times(forward_eager_time, forward_compiled_time),
-            bwd_times=Times(backward_eager_time, backward_compile_time),
-        )
-    else:
-        return ExperimentResults(
-            fwd_times=Times(forward_eager_time, forward_compiled_time),
-            bwd_times=None,
-        )
+    return ExperimentResults(
+        eager_time=forward_eager_time,
+        compiled_time=forward_compiled_time,
+    )
 
 
-def calculate_speedup(results: ExperimentResults, type: str) -> float:
-    if type == "fwd":
-        return results.fwd_times.eager_time / results.fwd_times.compiled_time
-    elif type == "bwd":
-        assert results.bwd_times is not None
-        return results.bwd_times.eager_time / results.bwd_times.compiled_time
-    else:
-        raise ValueError(f"Invalid type {type}")
+def calculate_speedup(results: ExperimentResults) -> float:
+    return results.eager_time / results.compiled_time
 
 
 def get_func_name(func):
     return func.__name__.split("<locals>.")[-1].split(" at ")[0]
 
 
-def get_average_speedups(results: List[Experiment], type: str):
+def get_average_speedups(results: List[Experiment]):
     # Calculate speedups
-    speedups = [calculate_speedup(r.results, type) for r in results]
+    speedups = [calculate_speedup(r.results) for r in results]
 
     # Find indices of max and min speedups
     max_speedup_index = np.argmax(speedups)
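The removed branch timed backward passes by replaying `out.backward` against a fixed upstream gradient. A standalone sketch of that pattern, with a toy forward standing in for attention and hypothetical sizes:

```python
import torch

x = torch.rand(8, 16, 512, 64, requires_grad=True)
out = x.sin()  # stand-in for the attention forward
dOut = torch.randn_like(out)  # fixed upstream gradient, as in the removed code
out.backward(dOut, retain_graph=True)  # retain_graph=True permits repeated timed calls
print(x.grad.shape)  # torch.Size([8, 16, 512, 64])
```

The restored `calculate_speedup` is simply `eager_time / compiled_time`; for example, 123.4 us eager over 61.7 us compiled reports a 2.000x speedup.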
@@ -209,39 +177,20 @@ def print_results(results: List[Experiment]):
     table_data = defaultdict(list)
     for experiment in results:
         for key, value in experiment.asdict().items():
-            if key == "fwd_times":
-                for name, time in value.items():
-                    table_data[f"fwd_{name}"].append(float(time))
-            elif key == "bwd_times":
-                if experiment.config.calculate_bwd_time:
-                    for name, time in value.items():
-                        table_data[f"bwd_{name}"].append(float(time))
-            else:
-                table_data[key].append(value)
+            if key == "eager_time" or key == "compiled_time":
+                value = float(value)
+            table_data[key].append(value)
 
     # Calculate speedups
-    fwd_speedups = [calculate_speedup(r.results, type="fwd") for r in results]
-    table_data["fwd_speedup"] = fwd_speedups
-    if results[0].config.calculate_bwd_time:
-        bwd_speedups = [calculate_speedup(r.results, type="bwd") for r in results]
-        table_data["bwd_speedup"] = bwd_speedups
+    speedups = [calculate_speedup(r.results) for r in results]
+    table_data["speedup"] = speedups
 
     table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]]
     print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
-
     print("\n")
-    print("FWD Speedups".center(125, "="))
-    print("\n")
-    average_data = get_average_speedups(results, type="fwd")
+    average_data = get_average_speedups(results)
     print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
-
-    if results[0].config.calculate_bwd_time:
-        print("\n")
-        print("BWD Speedups".center(125, "="))
-        print("\n")
-        average_data = get_average_speedups(results, type="bwd")
-        print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
 
 
 def generate_score_mods() -> List[Callable]:
     def noop(score, b, h, m, n):
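`print_results` feeds a dict of equal-length columns straight to `tabulate`. A minimal sketch of that call pattern with hypothetical rows:

```python
from collections import defaultdict

from tabulate import tabulate

table_data = defaultdict(list)
table_data["score_mod"] = ["noop", "causal_mask"]
table_data["eager_time"] = [123.4, 150.9]
table_data["compiled_time"] = [61.7, 70.2]
table_data["speedup"] = [2.0, 2.149]
# headers="keys" uses the dict keys as column headers; floatfmt rounds floats.
print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
```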
@@ -259,8 +208,8 @@ def generate_score_mods() -> List[Callable]:
     return [noop, causal_mask, relative_bias, head_bias]
 
 
-def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]:
-    batch_sizes = [2, 8, 16]
+def generate_experiment_configs() -> List[ExperimentConfig]:
+    batch_sizes = [1, 8, 16]
     num_heads = [16]
     q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
     head_dims = [64, 128, 256]
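The sweep enumerates the Cartesian product of these lists. The `dtypes` list lies outside this hunk, so counting a single dtype for illustration gives 3 x 1 x 3 x 3 x 4 = 108 configurations per dtype:

```python
import itertools

batch_sizes = [1, 8, 16]
num_heads = [16]
q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
head_dims = [64, 128, 256]
score_mods = ["noop", "causal_mask", "relative_bias", "head_bias"]  # names only, for counting
configs = list(
    itertools.product(batch_sizes, num_heads, q_kv_seq_lens, head_dims, score_mods)
)
print(len(configs))  # 108 per dtype
```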
@@ -279,49 +228,41 @@ def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]:
     ) in itertools.product(
         batch_sizes, num_heads, q_kv_seq_lens, head_dims, score_mods, dtypes
     ):
-        assert q_seq_len == kv_seq_len, "Only equal length inputs supported for now."
         all_configs.append(
             ExperimentConfig(
-                shape=(bsz, n_heads, q_seq_len, head_dim),
+                batch_size=bsz,
+                num_heads=n_heads,
+                q_seq_len=q_seq_len,
+                k_seq_len=kv_seq_len,
+                head_dim=head_dim,
                 score_mod=score_mod,
                 dtype=dtype,
-                calculate_bwd_time=calculate_bwd,
             )
         )
 
     return all_configs
 
 
-def main(dynamic: bool, calculate_bwd: bool):
+def main(dynamic=False):
     seed = 123
     np.random.seed(seed)
     torch.manual_seed(seed)
     results = []
-    for config in tqdm(generate_experiment_configs(calculate_bwd)):
-        results.append(
-            Experiment(config, run_single_experiment(config, dynamic=dynamic))
-        )
+    for config in tqdm(generate_experiment_configs()):
+        results.append(Experiment(config, run_single_experiment(config)))
 
     print_results(results)
 
 
 if __name__ == "__main__":
     # Set up the argument parser
-    parser = argparse.ArgumentParser(
-        description="Run sweep over sizes and score mods for flex attention"
-    )
+    parser = argparse.ArgumentParser()
     parser.add_argument(
         "--dynamic",
         action="store_true",
         help="Runs a dynamic shapes version of compiled flex attention.",
     )
-    parser.add_argument(
-        "--calculate-bwd", action="store_true", help="Calculate backward pass times"
-    )
 
     # Parse arguments
     args = parser.parse_args()
 
-    main(args.dynamic, args.calculate_bwd)
+    main(args.dynamic)
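After this revert the script accepts only the `--dynamic` flag; assuming the benchmark is saved as `score_mod.py` (the filename is not shown in this excerpt), a run looks like `python score_mod.py --dynamic`.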