Revert "Improved flexattention bwd perf + added configurations for benchmarks (#129013)"
This reverts commit ff89ebc50a738c734496393dc25313cf197fd0b4.
Reverted https://github.com/pytorch/pytorch/pull/129013 on behalf of https://github.com/huydhn due to: Sorry for reverting your change, but one of the test_torchinductor_opinfo tests started to fail after this commit ff89ebc50a; I am reverting to see if it helps trunk recover ([comment](https://github.com/pytorch/pytorch/pull/129013#issuecomment-2182042422))
@@ -106,9 +106,7 @@ def generate_inputs(
     return query, key, value


-def run_single_experiment(
-    config: ExperimentConfig, dynamic=False, max_autotune=False
-) -> ExperimentResults:
+def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults:
     device = torch.device("cuda")
     batch_size, num_heads, q_seq_len, head_dim = config.shape
     query, key, value = generate_inputs(
@@ -125,12 +123,7 @@ def run_single_experiment(
     def eager_sdpa(query, key, value, _):
         return F.scaled_dot_product_attention(query, key, value)

-    if max_autotune:
-        compiled_sdpa = torch.compile(
-            _flex_attention, dynamic=dynamic, mode="max-autotune-no-cudagraphs"
-        )
-    else:
-        compiled_sdpa = torch.compile(_flex_attention, dynamic=dynamic)
+    compiled_sdpa = torch.compile(_flex_attention, dynamic=dynamic)

     score_mod = config.score_mod

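Note: the removed branch above is what let the benchmark opt into Inductor's max-autotune mode. A minimal, hedged sketch of that compile toggle, using plain scaled_dot_product_attention as a stand-in for the benchmark's `_flex_attention` callable so the snippet is self-contained:

```python
import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa  # stand-in for _flex_attention

def build_compiled_attention(fn=sdpa, dynamic: bool = False, max_autotune: bool = False):
    # Sketch of the reverted toggle: "max-autotune-no-cudagraphs" asks Inductor to autotune
    # kernel/template choices without CUDA graph capture; the default mode skips autotuning.
    if max_autotune:
        return torch.compile(fn, dynamic=dynamic, mode="max-autotune-no-cudagraphs")
    return torch.compile(fn, dynamic=dynamic)
```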
@@ -251,7 +244,7 @@ def print_results(results: List[Experiment]):
     print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))


-def generate_score_mods(score_mods: List[str]) -> List[Callable]:
+def generate_score_mods() -> List[Callable]:
    def noop(score, b, h, m, n):
        return score

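A score_mod here is a callable that takes the pre-softmax attention score plus batch, head, query-index, and key-index values and returns a modified score. Only noop and head_bias appear in these hunks; a hedged sketch of what the causal and relative-bias variants referenced by name typically look like (the exact bodies in the benchmark may differ):

```python
import torch

# Illustrative score_mod callables following the (score, b, h, m, n) contract shown above.
def causal_mask(score, b, h, m, n):
    # Keep the score only where query index m may attend to key index n.
    return torch.where(m >= n, score, float("-inf"))

def relative_bias(score, b, h, m, n):
    # Bias the score by the signed distance between query and key positions.
    return score + (m - n)
```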
@@ -264,27 +257,18 @@ def generate_score_mods(score_mods: List[str]) -> List[Callable]:
     def head_bias(score, b, h, m, n):
         return score + 2 * h

-    function_dict = {
-        "noop": noop,
-        "causal": causal_mask,
-        "rel": relative_bias,
-        "head_bias": head_bias,
-    }
-    return [function_dict[name] for name in score_mods]
+    return [noop, causal_mask, relative_bias, head_bias]


-def generate_experiment_configs(
-    calculate_bwd: bool,
-    dtype: torch.dtype,
-    batch_sizes: List[int],
-    num_heads: List[int],
-    seq_lens: List[int],
-    head_dims: List[int],
-    score_mods: List[str],
-) -> List[ExperimentConfig]:
-    q_kv_seq_lens = [(i, i) for i in seq_lens]  # only testing q_len == kv_len
-    dtypes = [dtype]
-    score_mods = generate_score_mods(score_mods)
+def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]:
+    batch_sizes = [2, 8, 16]
+    num_heads = [16]
+    q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
+    head_dims = [64, 128]
+    dtypes = [
+        torch.bfloat16,
+    ]
+    score_mods = generate_score_mods()
     all_configs = []
     for (
         bsz,
@@ -309,23 +293,14 @@ def generate_experiment_configs(
     return all_configs


-def main(args):
+def main(dynamic: bool, calculate_bwd: bool):
     seed = 123
     np.random.seed(seed)
     torch.manual_seed(seed)
     results = []
-    for config in tqdm(
-        generate_experiment_configs(
-            args.calculate_bwd, args.dtype, args.b, args.nh, args.s, args.d, args.mods
-        )
-    ):
+    for config in tqdm(generate_experiment_configs(calculate_bwd)):
         results.append(
-            Experiment(
-                config,
-                run_single_experiment(
-                    config, dynamic=args.dynamic, max_autotune=args.max_autotune
-                ),
-            )
+            Experiment(config, run_single_experiment(config, dynamic=dynamic))
         )

     print_results(results)
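The loop above wraps each ExperimentConfig together with its measured ExperimentResults in an Experiment record. A hedged sketch of how those containers might be laid out, inferred from how config.shape is unpacked earlier in the diff (the benchmark's actual dataclasses carry more detail):

```python
import torch
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

# Hedged sketch only; field names beyond `shape` and `score_mod` are assumptions.
@dataclass(frozen=True)
class ExperimentConfig:
    shape: Tuple[int, int, int, int]  # (batch_size, num_heads, q_seq_len, head_dim)
    score_mod: Callable
    dtype: torch.dtype
    calculate_bwd_time: bool

@dataclass(frozen=True)
class ExperimentResults:
    fwd_time: float
    bwd_time: Optional[float]

@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults
```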
@@ -345,29 +320,7 @@ if __name__ == "__main__":
         "--calculate-bwd", action="store_true", help="Calculate backward pass times"
     )
-
-    parser.add_argument("-dtype", type=str, help="dtype", default="bfloat16")
-
-    parser.add_argument(
-        "-b", type=int, nargs="+", help="batch sizes", default=[2, 8, 16]
-    )
-    parser.add_argument("-nh", type=int, nargs="+", help="# of heads", default=[16])
-    parser.add_argument(
-        "-s", type=int, nargs="+", help="sequence lengths", default=[512, 1024, 4096]
-    )
-    parser.add_argument("-d", type=int, nargs="+", help="head dims", default=[64, 128])
-    parser.add_argument(
-        "-mods",
-        type=str,
-        nargs="+",
-        help="score mods",
-        default=["noop", "causal", "rel", "head_bias"],
-    )
-    parser.add_argument(
-        "--max-autotune", action="store_true", help="Turn on max-autotune"
-    )

     # Parse arguments
     args = parser.parse_args()
-    args.dtype = getattr(torch, args.dtype)

-    main(args)
+    main(args.dynamic, args.calculate_bwd)

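One detail of the removed argument handling: the `-dtype` flag took a string which `args.dtype = getattr(torch, args.dtype)` resolved into a real dtype object. A small illustration of that mapping:

```python
import torch

# Resolve a dtype name (as it would arrive from the command line) into a torch.dtype.
dtype_name = "bfloat16"                # default of the removed -dtype flag
dtype = getattr(torch, dtype_name)     # an unknown name raises AttributeError
assert isinstance(dtype, torch.dtype)
```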
@@ -32,7 +32,6 @@ from torch.utils._triton import has_triton
 # Skip tests if Triton is not available
 supported_platform = skipUnless(
     torch.cuda.is_available()
-    and torch.version.hip is None
     and has_triton()
     and torch.cuda.get_device_capability() >= (8, 0),
     "Requires CUDA and Triton",
@@ -277,7 +277,7 @@ def flex_attention_fake_tensor_mode(
         logsumexp = query.new_empty(
             batch_size, num_heads, seq_len_q, dtype=torch.float32
         )
-        return torch.empty_like(query), logsumexp
+        return torch.empty_like(query, memory_format=torch.contiguous_format), logsumexp


 # ---------------------------- Autograd Implementation ----------------------------
@@ -670,9 +670,9 @@ def flex_attention_backward_fake_tensor_mode(
     *other_buffers: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     with mode:
-        grad_query = torch.empty_like(query)
-        grad_key = torch.empty_like(key)
-        grad_value = torch.empty_like(value)
+        grad_query = torch.empty_like(query, memory_format=torch.contiguous_format)
+        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
+        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)
         return grad_query, grad_key, grad_value

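The only change in these two fake-tensor hunks is the stride policy of the outputs: plain torch.empty_like preserves the input's (possibly non-contiguous) layout, while memory_format=torch.contiguous_format forces contiguous outputs. A small illustration:

```python
import torch

# A non-contiguous "query"-like tensor, e.g. produced by a transpose.
q = torch.empty(2, 8, 1024, 64).transpose(1, 2)

same_layout = torch.empty_like(q)                                         # keeps q's permuted strides
contiguous = torch.empty_like(q, memory_format=torch.contiguous_format)   # forces contiguous strides

print(same_layout.is_contiguous())  # False
print(contiguous.is_contiguous())   # True
```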
@@ -17,7 +17,7 @@ from ..ir import (
     Subgraph,
     TensorBox,
 )
-from ..lowering import empty_strided, lowerings, register_lowering
+from ..lowering import empty_strided, full, lowerings, register_lowering
 from ..select_algorithm import autotune_select_algorithm, TritonTemplate

 log = logging.getLogger(__name__)
@@ -314,9 +314,6 @@ _h100_default_config = {
     (torch.bfloat16, 64): (128, 64, 4, 3),
     (torch.bfloat16, 128): (64, 32, 4, 3),
     (torch.bfloat16, 256): (64, 32, 4, 3),
-    (torch.float16, 64): (128, 64, 4, 3),
-    (torch.float16, 128): (64, 32, 4, 3),
-    (torch.float16, 256): (64, 32, 4, 3),
 }

 _a100_default_config = {
@@ -324,11 +321,8 @@ _a100_default_config = {
     (torch.float32, 128): (128, 32, 4, 3),
     (torch.float32, 256): (64, 16, 4, 3),
     (torch.bfloat16, 64): (128, 64, 4, 3),
-    (torch.bfloat16, 128): (128, 128, 8, 2),
+    (torch.bfloat16, 128): (128, 32, 4, 3),
     (torch.bfloat16, 256): (32, 64, 4, 3),
-    (torch.float16, 64): (128, 64, 4, 3),
-    (torch.float16, 128): (128, 128, 8, 2),
-    (torch.float16, 256): (32, 64, 4, 3),
 }

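These per-GPU dictionaries map a (dtype, head_dim) pair to a Triton launch configuration; judging from how the backward configs are unpacked later in this file (`for BLOCK1, BLOCK2, num_warps, num_stages in configs:`), each tuple holds two block sizes plus num_warps and num_stages. A hedged sketch of the lookup pattern:

```python
import torch

# Hedged sketch; field meanings are inferred from the unpacking shown further down.
_example_default_config = {
    (torch.bfloat16, 64): (128, 64, 4, 3),   # (BLOCK_1, BLOCK_2, num_warps, num_stages)
    (torch.bfloat16, 128): (64, 32, 4, 3),
}

def pick_config(dtype, head_dim, fallback=(64, 64, 4, 2)):
    # Fall back to a conservative tiling when the exact (dtype, head_dim) pair is not tabulated.
    return _example_default_config.get((dtype, head_dim), fallback)

block_1, block_2, num_warps, num_stages = pick_config(torch.bfloat16, 64)
```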
@@ -362,17 +356,12 @@ def _get_default_config_bwd(query) -> Tuple[int, int, int, int]:
     head_dim = query.get_size()[-1]
     dtype = query.get_dtype()

-    if dtype == torch.float32:
-        return (16, 16, 4, 1)
     if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
-        return (32, 128, 4, 3)
-    elif torch.cuda.get_device_capability() >= (8, 0):  # A100
-        if head_dim == 64:
-            return (32, 128, 4, 3)
-        elif head_dim == 128:
-            return (64, 128, 8, 3)
-        else:
-            return (64, 64, 4, 2)
+        if dtype == torch.float32:
+            return (64, 64, 4, 1)
+        return (128, 128, 4, 3)
+    elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0):  # A100
+        return (64, 64, 4, 1)
     else:  # modest hardware or extremely large head_dim
         return (16, 16, 4, 1)

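The branches above key off the CUDA compute capability, which PyTorch reports as a (major, minor) tuple; (9, 0) corresponds to Hopper-class parts such as H100 and (8, 0) to Ampere-class parts such as A100. A small illustration of that check:

```python
import torch

if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()  # e.g. (9, 0) on H100, (8, 0) on A100
    if capability >= (9, 0):
        print("Hopper-class GPU")
    elif capability >= (8, 0):
        print("Ampere-class GPU")
    else:
        print("older GPU: a conservative tiling would be chosen")
```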
@@ -400,7 +389,7 @@ def flex_attention(*args, **kwargs):
         query.get_device(),
         query.get_dtype(),
         query.get_size(),
-        query.get_stride(),
+        FlexibleLayout.contiguous_strides(query.get_size()),
     )
     # see NOTE:[TritonTemplates with multiple outputs]
     logsumexp_shape = query.get_size()[:-1]  # [B, H, M]
@@ -572,7 +561,7 @@ flex_attention_backward_template = TritonTemplate(
     curr_n = start_n2
     num_steps = KV_LEN // BLOCK_N2
     for blk_idx in range(num_steps):
-        offs_n2 = curr_n + tl.arange(0, BLOCK_N2)
+        offs_n2= curr_n + tl.arange(0, BLOCK_N2)
         kT = tl.load(kT_ptrs)
         vT = tl.load(vT_ptrs)
         qk = tl.dot(q, kT)
@@ -702,8 +691,8 @@ flex_attention_backward_template = TritonTemplate(
     # Write back dK.
     index_n = offs_n1[:, None]
     index_k = offs_k[None, :]
-
-    mask = index_n <= KV_LEN
+    # TODO generalize and add proper mask support
+    mask = (index_n != -1) & (index_k != -1)
     {{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
 """,
 )
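The restored mask `(index_n != -1) & (index_k != -1)` is effectively an always-true placeholder (hence the restored TODO), while the reverted line bounded the row index by KV_LEN. In Triton such masks decide which lanes of a block actually touch memory; a minimal standalone sketch of a masked load/store (requires a CUDA device with Triton installed):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def masked_copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements                 # lanes past the end neither load nor store
    x = tl.load(src_ptr + offs, mask=mask)
    tl.store(dst_ptr + offs, x, mask=mask)

src = torch.arange(100, device="cuda", dtype=torch.float32)
dst = torch.empty_like(src)
masked_copy_kernel[(triton.cdiv(100, 32),)](src, dst, 100, BLOCK=32)
assert torch.equal(src, dst)
```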
@@ -756,7 +745,7 @@ def flex_attention_backward(*args, **kwargs):
         key.get_device(),
         key.get_dtype(),
         key.get_size(),
-        key.get_stride(),
+        FlexibleLayout.contiguous_strides(key.get_size()),
     )

     # Create delta which will is needed for the bwd's kernel
@@ -764,23 +753,20 @@ def flex_attention_backward(*args, **kwargs):
     delta = lowerings[aten.sum](mul_delta, axis=-1)

     # see NOTE:[TritonTemplates with multiple outputs]
-    grad_query = empty_strided(
-        query.get_size(), query.get_stride(), dtype=dtype, device=device
-    )
-    grad_value = empty_strided(
-        value.get_size(), value.get_stride(), dtype=dtype, device=device
-    )
+    grad_query = full(
+        query.get_size(), 0.0, dtype=dtype, device=device
+    )  # torch.zeros equivalent
+    grad_query.realize()
+    grad_value = empty_strided(value.get_size(), None, dtype=dtype, device=device)

     choices: List[Any] = []
     configs: List[Tuple[int, int, int, int]] = []
     configs.append(_get_default_config_bwd(query))
     if config.max_autotune:
         for BLOCK1 in [32, 64]:
-            for BLOCK2 in [32, 64, 128]:
-                if BLOCK2 % BLOCK1 != 0:
-                    continue
+            for BLOCK2 in [32, 64]:
                 for w in [4, 8]:
-                    for s in [1, 3, 4, 5]:
+                    for s in [1, 3]:
                         configs.append((BLOCK1, BLOCK2, w, s))

     for BLOCK1, BLOCK2, num_warps, num_stages in configs:
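Restoring `full(..., 0.0, ...)` for grad_query (instead of the reverted empty_strided) matters if the backward kernel accumulates partial dQ contributions into the buffer rather than writing each element exactly once; an uninitialized buffer would then start from garbage. The same point in plain eager terms:

```python
import torch

kv_blocks = 4
partial_dq = [torch.randn(8, 64) for _ in range(kv_blocks)]

# Zero-initialized accumulator (what full(size, 0.0, ...) provides): correct.
dq = torch.zeros(8, 64)
for p in partial_dq:
    dq += p

# A buffer from empty_strided starts with arbitrary contents, so accumulating into it
# without first zeroing (or writing) every element would give wrong gradients.
```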
@@ -804,9 +790,9 @@ def flex_attention_backward(*args, **kwargs):
             num_stages=num_stages,
             num_warps=num_warps,
             BLOCK_M1=BLOCK1,
-            BLOCK_N1=BLOCK2,
+            BLOCK_N1=BLOCK1,
             BLOCK_M2=BLOCK2,
-            BLOCK_N2=BLOCK1,
+            BLOCK_N2=BLOCK2,
             BLOCK_DMODEL=query.get_size()[-1],
             # For now, we always assume the "sound" option
             SCORE_MOD_IS_LINEAR=False,
@@ -102,8 +102,6 @@ class PartialRender:
         return self.code


-# This is used to store info needed for lowering each subgraph in triton
-# templates
 SubgraphInfo = namedtuple(
     "SubgraphInfo",
     [
@@ -1578,11 +1576,8 @@ class AlgorithmSelectorCache(PersistentCache):
             for choice in top_k:
                 result = timings[choice]
                 if result:
-                    kernel_info = (
-                        choice.debug_extra if hasattr(choice, "debug_extra") else ""
-                    )
                     sys.stderr.write(
-                        f" {choice.name} {result:.4f} ms {best_time / result:.1%} {kernel_info}\n"
+                        f" {choice.name} {result:.4f} ms {best_time / result:.1%}\n"
                     )
                 else:
                     sys.stderr.write(