Revert "Add Lowering for FlexAttention Backwards (#125515)"

This reverts commit 95b9e981c3ab68fc17f78b8a6bbfd9569745ae4c.

Reverted https://github.com/pytorch/pytorch/pull/125515 on behalf of https://github.com/huydhn. Reason: "Sorry for reverting your change, but the newly added test runs out of memory" (see 95b9e981c3 and this [comment](https://github.com/pytorch/pytorch/pull/125515#issuecomment-2114084869)).
This commit is contained in:
PyTorch MergeBot
2024-05-16 05:52:13 +00:00
parent cdcba4dee5
commit 0716f75cfb
9 changed files with 309 additions and 828 deletions

View File

@ -3,7 +3,7 @@ import itertools
from collections import defaultdict
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, List, Optional, Tuple
from typing import Callable, List
import numpy as np
import torch
@ -29,32 +29,28 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->
@dataclass(frozen=True)
class ExperimentConfig:
shape: Tuple[int]
batch_size: int
num_heads: int
q_seq_len: int
k_seq_len: int
head_dim: int
score_mod: Callable
dtype: torch.dtype
calculate_bwd_time: bool
def __post_init__(self):
assert len(self.shape) == 4, "Shape must be of length 4"
def asdict(self):
# Convert the dataclass instance to a dictionary
d = asdict(self)
# Remove the 'calculate_bwd_time' key
d.pop("calculate_bwd_time", None)
return d
@dataclass(frozen=True)
class Times:
eager_time: float
compiled_time: float
return asdict(self)
@dataclass(frozen=True)
class ExperimentResults:
fwd_times: Times
bwd_times: Optional[Times]
eager_time: float
compiled_time: float
def get_entries(self) -> List:
return [
f"{self.eager_time:2f}",
f"{self.compiled_time:2f}",
]
@dataclass(frozen=True)
@ -62,31 +58,29 @@ class Experiment:
config: ExperimentConfig
results: ExperimentResults
def get_entries(self) -> List:
return self.config.get_entries() + self.results.get_entries()
def asdict(self):
dict1 = self.config.asdict()
dict1 = asdict(self.config)
dict2 = asdict(self.results)
return {**dict1, **dict2}
def generate_inputs(
batch_size: int,
num_heads: int,
q_sequence_length: int,
kv_sequence_length: int,
head_dim: int,
dtype: torch.dtype,
device: torch.device,
requires_grad: bool,
batch_size,
num_heads,
q_sequence_length,
kv_sequence_length,
head_dim,
dtype,
device,
):
q_shape = (batch_size, q_sequence_length, num_heads * head_dim)
kv_shape = (batch_size, kv_sequence_length, num_heads * head_dim)
make_q = partial(
torch.rand, q_shape, device=device, dtype=dtype, requires_grad=requires_grad
)
make_kv = partial(
torch.rand, kv_shape, device=device, dtype=dtype, requires_grad=requires_grad
)
make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
query = (
make_q()
.view(batch_size, q_sequence_length, num_heads, head_dim)
@ -107,16 +101,14 @@ def generate_inputs(
def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults:
device = torch.device("cuda")
batch_size, num_heads, q_seq_len, head_dim = config.shape
query, key, value = generate_inputs(
batch_size,
num_heads,
q_seq_len,
q_seq_len,
head_dim,
config.batch_size,
config.num_heads,
config.q_seq_len,
config.k_seq_len,
config.head_dim,
config.dtype,
device,
requires_grad=config.calculate_bwd_time,
)
def eager_sdpa(query, key, value, _):
@ -133,47 +125,23 @@ def run_single_experiment(config: ExperimentConfig, dynamic=False) -> Experiment
compiled_sdpa, query, key, value, score_mod
)
if config.calculate_bwd_time:
out_eager = eager_sdpa(query, key, value, score_mod)
dOut = torch.randn_like(out_eager)
backward_eager_time = benchmark_torch_function_in_microseconds(
out_eager.backward, dOut, retain_graph=True
)
out_compile = compiled_sdpa(query, key, value, score_mod)
dOut = torch.randn_like(out_eager)
backward_compile_time = benchmark_torch_function_in_microseconds(
out_compile.backward, dOut, retain_graph=True
)
return ExperimentResults(
fwd_times=Times(forward_eager_time, forward_compiled_time),
bwd_times=Times(backward_eager_time, backward_compile_time),
)
else:
return ExperimentResults(
fwd_times=Times(forward_eager_time, forward_compiled_time),
bwd_times=None,
)
return ExperimentResults(
eager_time=forward_eager_time,
compiled_time=forward_compiled_time,
)
def calculate_speedup(results: ExperimentResults, type: str) -> float:
if type == "fwd":
return results.fwd_times.eager_time / results.fwd_times.compiled_time
elif type == "bwd":
assert results.bwd_times is not None
return results.bwd_times.eager_time / results.bwd_times.compiled_time
else:
raise ValueError(f"Invalid type {type}")
def calculate_speedup(results: ExperimentResults) -> float:
return results.eager_time / results.compiled_time
def get_func_name(func):
return func.__name__.split("<locals>.")[-1].split(" at ")[0]
def get_average_speedups(results: List[Experiment], type: str):
def get_average_speedups(results: List[Experiment]):
# Calculate speedups
speedups = [calculate_speedup(r.results, type) for r in results]
speedups = [calculate_speedup(r.results) for r in results]
# Find indices of max and min speedups
max_speedup_index = np.argmax(speedups)
@ -209,39 +177,20 @@ def print_results(results: List[Experiment]):
table_data = defaultdict(list)
for experiment in results:
for key, value in experiment.asdict().items():
if key == "fwd_times":
for name, time in value.items():
table_data[f"fwd_{name}"].append(float(time))
elif key == "bwd_times":
if experiment.config.calculate_bwd_time:
for name, time in value.items():
table_data[f"bwd_{name}"].append(float(time))
else:
table_data[key].append(value)
if key == "eager_time" or key == "compiled_time":
value = float(value)
table_data[key].append(value)
# Calculate speedups
fwd_speedups = [calculate_speedup(r.results, type="fwd") for r in results]
table_data["fwd_speedup"] = fwd_speedups
if results[0].config.calculate_bwd_time:
bwd_speedups = [calculate_speedup(r.results, type="bwd") for r in results]
table_data["bwd_speedup"] = bwd_speedups
speedups = [calculate_speedup(r.results) for r in results]
table_data["speedup"] = speedups
table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]]
print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
print("\n")
print("FWD Speedups".center(125, "="))
print("\n")
average_data = get_average_speedups(results, type="fwd")
average_data = get_average_speedups(results)
print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
if results[0].config.calculate_bwd_time:
print("\n")
print("BWD Speedups".center(125, "="))
print("\n")
average_data = get_average_speedups(results, type="bwd")
print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
def generate_score_mods() -> List[Callable]:
def noop(score, b, h, m, n):
@ -259,8 +208,8 @@ def generate_score_mods() -> List[Callable]:
return [noop, causal_mask, relative_bias, head_bias]
def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]:
batch_sizes = [2, 8, 16]
def generate_experiment_configs() -> List[ExperimentConfig]:
batch_sizes = [1, 8, 16]
num_heads = [16]
q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
head_dims = [64, 128, 256]
@ -279,49 +228,41 @@ def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]:
) in itertools.product(
batch_sizes, num_heads, q_kv_seq_lens, head_dims, score_mods, dtypes
):
assert q_seq_len == kv_seq_len, "Only equal length inputs supported for now."
all_configs.append(
ExperimentConfig(
shape=(bsz, n_heads, q_seq_len, head_dim),
batch_size=bsz,
num_heads=n_heads,
q_seq_len=q_seq_len,
k_seq_len=kv_seq_len,
head_dim=head_dim,
score_mod=score_mod,
dtype=dtype,
calculate_bwd_time=calculate_bwd,
)
)
return all_configs
def main(dynamic: bool, calculate_bwd: bool):
def main(dynamic=False):
seed = 123
np.random.seed(seed)
torch.manual_seed(seed)
results = []
for config in tqdm(generate_experiment_configs(calculate_bwd)):
for config in tqdm(generate_experiment_configs()):
results.append(
Experiment(config, run_single_experiment(config, dynamic=dynamic))
)
for config in tqdm(generate_experiment_configs(calculate_bwd)):
results.append(Experiment(config, run_single_experiment(config)))
print_results(results)
if __name__ == "__main__":
# Set up the argument parser
parser = argparse.ArgumentParser(
description="Run sweep over sizes and score mods for flex attention"
)
parser = argparse.ArgumentParser()
parser.add_argument(
"--dynamic",
action="store_true",
help="Runs a dynamic shapes version of compiled flex attention.",
)
parser.add_argument(
"--calculate-bwd", action="store_true", help="Calculate backward pass times"
)
# Parse arguments
args = parser.parse_args()
main(args.dynamic, args.calculate_bwd)
main(args.dynamic)