Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Inductor] Flex attention supports dynamic shape (#125994)
## static shapes perf

| Type    | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype          |
|---------|---------|------------|-----------|-----------|-----------|----------|-----------|----------------|
| Average | 0.692   |            |           |           |           |          |           |                |
| Max     | 0.855   | 16         | 16        | 4096      | 4096      | 64       | head_bias | torch.bfloat16 |
| Min     | 0.419   | 8          | 16        | 512       | 512       | 256      | noop      | torch.bfloat16 |

## dynamic shapes perf

| Type    | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod     | dtype          |
|---------|---------|------------|-----------|-----------|-----------|----------|---------------|----------------|
| Average | 0.670   |            |           |           |           |          |               |                |
| Max     | 0.864   | 16         | 16        | 4096      | 4096      | 64       | relative_bias | torch.bfloat16 |
| Min     | 0.376   | 8          | 16        | 512       | 512       | 256      | relative_bias | torch.bfloat16 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125994
Approved by: https://github.com/Chillee
Committed by: PyTorch MergeBot
Parent: 1485621ccb
Commit: dfab69fdf1
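The diff below threads a `dynamic` flag through the benchmark and tests so flex attention can be compiled with symbolic batch and sequence sizes. As a minimal sketch of what that enables: the import path of the private `_flex_attention` entry point and the `relative_bias` score_mod below are assumptions for illustration, not part of this diff.

```python
import torch

# Assumed import path for the private flex attention entry point used by the
# benchmark in this PR; the exact module may differ between releases.
from torch.nn.attention._flex_attention import _flex_attention


def relative_bias(score, b, h, q_idx, kv_idx):
    # Example score_mod: bias the attention score by the query/key distance.
    return score + (q_idx - kv_idx)


# dynamic=True asks torch.compile to trace one graph with symbolic
# batch/sequence sizes instead of specializing on the first shapes it sees.
compiled = torch.compile(_flex_attention, dynamic=True)

q = torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.bfloat16)
out = compiled(q, torch.randn_like(q), torch.randn_like(q), relative_bias)

# A call with different batch/sequence sizes should reuse the same compiled
# kernel rather than triggering a recompile.
q2 = torch.randn(16, 16, 1024, 64, device="cuda", dtype=torch.bfloat16)
out2 = compiled(q2, torch.randn_like(q2), torch.randn_like(q2), relative_bias)
```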
@@ -1,3 +1,4 @@
+import argparse
 import itertools
 from collections import defaultdict
 from dataclasses import asdict, dataclass
@@ -98,7 +99,7 @@ def generate_inputs(
     return query, key, value
 
 
-def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
+def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults:
     device = torch.device("cuda")
     query, key, value = generate_inputs(
         config.batch_size,
@@ -113,7 +114,7 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
     def eager_sdpa(query, key, value, _):
         return F.scaled_dot_product_attention(query, key, value)
 
-    compiled_sdpa = torch.compile(_flex_attention)
+    compiled_sdpa = torch.compile(_flex_attention, dynamic=dynamic)
 
     score_mod = config.score_mod
 
@@ -242,16 +243,26 @@ def generate_experiment_configs() -> List[ExperimentConfig]:
     return all_configs
 
 
-def main():
+def main(dynamic=False):
     seed = 123
     np.random.seed(seed)
     torch.manual_seed(seed)
     results = []
     for config in tqdm(generate_experiment_configs()):
-        results.append(Experiment(config, run_single_experiment(config)))
+        results.append(
+            Experiment(config, run_single_experiment(config, dynamic=dynamic))
+        )
 
     print_results(results)
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dynamic",
+        action="store_true",
+        help="Runs a dynamic shapes version of compiled flex attention.",
+    )
+
+    args = parser.parse_args()
+    main(args.dynamic)
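With the argparse change above, the benchmark gains a `--dynamic` flag that switches `torch.compile` into dynamic-shape mode, which is presumably how the two tables in the commit message were produced. The snippet below is a minimal, standalone timing sketch of the same comparison; it is not the benchmark's actual harness, and the `_flex_attention` import path and `noop` score_mod are assumptions.

```python
import torch
from torch.utils.benchmark import Timer

# Assumed import path for the private flex attention entry point.
from torch.nn.attention._flex_attention import _flex_attention


def noop(score, b, h, q_idx, kv_idx):
    return score


def time_flex(dynamic: bool, B=8, H=16, S=512, D=64) -> float:
    """Return mean seconds/iter for compiled flex attention on one shape."""
    q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    fn = torch.compile(_flex_attention, dynamic=dynamic)
    fn(q, k, v, noop)  # warm up so compilation is excluded from the timing
    timer = Timer(
        stmt="fn(q, k, v, noop)",
        globals={"fn": fn, "q": q, "k": k, "v": v, "noop": noop},
    )
    return timer.timeit(100).mean


if __name__ == "__main__":
    print("static :", time_flex(dynamic=False))
    print("dynamic:", time_flex(dynamic=True))
```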
@@ -126,6 +126,19 @@ D = 64
 
 
 class TestTemplatedSDPA(InductorTestCase):
+    def _check_equal(self, golden_out, ref_out, compiled_out, dtype):
+        compiled_error = (golden_out - compiled_out).abs().mean()
+        ref_error = (golden_out - ref_out).abs().mean()
+        # Note, it seems like we really are less accurate than the float32
+        # computation, likely due to the online softmax
+        if dtype == torch.float32:
+            fudge_factor = 10.0
+        else:
+            fudge_factor = 1.1
+        if compiled_error > ref_error * fudge_factor:
+            msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
+            self.assertTrue(False, msg)
+
     def run_test(
         self,
         score_mod: Callable,
@@ -145,18 +158,114 @@ class TestTemplatedSDPA(InductorTestCase):
         )
         ref_out = sdpa_partial(q, k, v)
         compiled_out = compiled_sdpa(q, k, v)
+        self._check_equal(golden_out, ref_out, compiled_out, dtype)
 
-        compiled_error = (golden_out - compiled_out).abs().mean()
-        ref_error = (golden_out - ref_out).abs().mean()
-        # Note, it seems like we really are less accurate than the float32
-        # computation, likely due to the online softmax
-        if dtype == torch.float32:
-            fudge_factor = 10.0
-        else:
-            fudge_factor = 1.1
-        if compiled_error > ref_error * fudge_factor:
-            msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
-            self.assertTrue(False, msg)
+    def run_dynamic_test(
+        self,
+        score_mod: Callable,
+        dtype: torch.dtype = torch.float16,
+        B: int = B,
+        H: int = H,
+        S: int = S,
+        D: int = D,
+    ):
+        sdpa_partial = create_attention(score_mod)
+        # The first eager batch, shape (B, H, S, D)
+        q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out1 = sdpa_partial(
+            q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
+        )
+        ref_out1 = sdpa_partial(q1, k1, v1)
+
+        # The second eager batch, shape (B * 2, H, S / 2, D)
+        B = int(B * 2)
+        S = int(S / 2)
+        q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out2 = sdpa_partial(
+            q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
+        )
+        ref_out2 = sdpa_partial(q2, k2, v2)
+
+        # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
+        # We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation.
+        torch._dynamo.reset()
+        # Compiling with dynamic shape in the first batch.
+        compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
+        compiled_out1 = compiled_sdpa(q1, k1, v1)
+        self._check_equal(golden_out1, ref_out1, compiled_out1, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
+
+        # No re-compilation, use the compiled dynamic shape version.
+        compiled_out2 = compiled_sdpa(q2, k2, v2)
+        self._check_equal(golden_out2, ref_out2, compiled_out2, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
+
+    def run_automatic_dynamic_test(
+        self,
+        score_mod: Callable,
+        dtype: torch.dtype = torch.float16,
+        B: int = B,
+        H: int = H,
+        S: int = S,
+        D: int = D,
+    ):
+        sdpa_partial = create_attention(score_mod)
+        # The first eager batch, shape (B, H, S, D)
+        q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out1 = sdpa_partial(
+            q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
+        )
+        ref_out1 = sdpa_partial(q1, k1, v1)
+
+        # The second eager batch, shape (B * 2, H, S / 2, D)
+        B = int(B * 2)
+        S = int(S / 2)
+        q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out2 = sdpa_partial(
+            q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
+        )
+        ref_out2 = sdpa_partial(q2, k2, v2)
+
+        # The third eager batch, shape (B * 4, H, S / 4, D)
+        B = int(B * 2)
+        S = int(S / 2)
+        q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out3 = sdpa_partial(
+            q3.to(torch.float64), k3.to(torch.float64), v3.to(torch.float64)
+        )
+        ref_out3 = sdpa_partial(q3, k3, v3)
+
+        # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
+        # We check dynamo counters["frames"]["ok"] to ensure:
+        # 1, the first batch is compiled with static shape
+        # 2, the second batch is compiled with dynamic shape
+        # 3, no re-compilation in the third batch
+        torch._dynamo.reset()
+        # The first batch.
+        compiled_sdpa = torch.compile(sdpa_partial)
+        compiled_out1 = compiled_sdpa(q1, k1, v1)
+        self._check_equal(golden_out1, ref_out1, compiled_out1, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
+
+        # The second batch (automatic dynamic).
+        compiled_out2 = compiled_sdpa(q2, k2, v2)
+        self._check_equal(golden_out2, ref_out2, compiled_out2, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
+
+        # The third batch (no re-compilation).
+        compiled_out3 = compiled_sdpa(q3, k3, v3)
+        self._check_equal(golden_out3, ref_out3, compiled_out3, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
 
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
@@ -164,6 +273,20 @@ class TestTemplatedSDPA(InductorTestCase):
     def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable):
         self.run_test(score_mod, dtype)
 
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes)
+    @common_utils.parametrize("score_mod", test_score_mods)
+    def test_builtin_score_mods_dynamic(self, dtype: torch.dtype, score_mod: Callable):
+        self.run_dynamic_test(score_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes)
+    @common_utils.parametrize("score_mod", test_score_mods)
+    def test_builtin_score_mods_automatic_dynamic(
+        self, dtype: torch.dtype, score_mod: Callable
+    ):
+        self.run_automatic_dynamic_test(score_mod, dtype)
+
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
     def test_skip_odd_keys(self, dtype: torch.dtype):
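The new tests rely on Dynamo's frame counters to prove that later batches with different shapes did not trigger recompilation. The standalone sketch below distills that pattern on a generic function; `toy_attention` is a stand-in for illustration, not part of the test file, and a CUDA device is assumed as in the tests.

```python
import torch


def toy_attention(q, k, v):
    # Stand-in for the sdpa_partial used by run_dynamic_test.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return scores.softmax(dim=-1) @ v


# Clear counters first; flex attention's eager path also goes through dynamo.
torch._dynamo.reset()
compiled = torch.compile(toy_attention, dynamic=True)

q1 = torch.randn(4, 8, 256, 64, device="cuda")
compiled(q1, q1, q1)
assert torch._dynamo.utils.counters["frames"]["ok"] == 1  # one compiled frame

# Different batch and sequence length: the dynamic graph should be reused,
# so the "ok" frame counter must not grow.
q2 = torch.randn(8, 8, 128, 64, device="cuda")
compiled(q2, q2, q2)
assert torch._dynamo.utils.counters["frames"]["ok"] == 1
```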
@@ -628,3 +628,7 @@ def is_from_defaults(source: Source):
     if isinstance(source, ChainedSource):
         return is_from_defaults(source.base)
     return False
+
+
+def is_cell_contents(source: Source):
+    return isinstance(source, AttrSource) and source.member == "cell_contents"
@@ -58,6 +58,7 @@ from ..source import (
     FloatTensorSource,
     GetItemSource,
     GradSource,
+    is_cell_contents,
     is_constant_source,
     is_from_defaults,
     is_from_optimizer_source,
@@ -1166,6 +1167,7 @@ class VariableBuilder:
             # NN modules on the fly)
             or self.source.guard_source().is_nn_module()
             or is_from_defaults(self.source)
+            or is_cell_contents(self.source)
         ):
             self.install_guards(GuardBuilder.CONSTANT_MATCH)
             return ConstantVariable.create(value=value, source=self.source)
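The new `is_cell_contents` source check lets `VariableBuilder` install a CONSTANT_MATCH guard for values reached through a closure cell, presumably so that Python scalars a score_mod closes over are treated as constants rather than being lifted into the graph as dynamic inputs. A plain-Python illustration of the kind of value this targets (`make_score_mod` is a hypothetical helper, not part of this diff):

```python
# A scalar captured in a closure cell; Dynamo reaches it through an
# AttrSource whose member is "cell_contents".
def make_score_mod(scale):
    def score_mod(score, b, h, q_idx, kv_idx):
        # `scale` lives in score_mod.__closure__[0].cell_contents
        return score * scale

    return score_mod


sm = make_score_mod(2)
print(sm.__closure__[0].cell_contents)  # -> 2
```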
@@ -162,7 +162,7 @@ sdpa_template = TritonTemplate(
 
     # TODO generalize and add proper mask support
     mask = (idx_m != -1) & (idx_d != -1)
-    {{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc")}}
+    {{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc", "mask")}}
 
     # TODO dont want to write this if we dont require grad
     if OUTPUT_LOGSUMEXP:
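The template change above passes `mask` through to `store_output`; the mask itself is still the always-true placeholder flagged by the TODO, but wiring it in is what lets a real bounds mask guard the final store once shapes are no longer known to divide the block size. As a generic illustration of the masked-store idiom in a Triton kernel (independent of Inductor's template machinery; `copy_rows` is a made-up example):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_rows_kernel(x_ptr, out_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # When M is not a multiple of BLOCK_M, the last program instance must mask
    # off the rows that fall outside the tensor before loading/storing.
    pid = tl.program_id(0)
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    ptrs = offs_m[:, None] * N + offs_n[None, :]
    vals = tl.load(x_ptr + ptrs, mask=mask, other=0.0)
    tl.store(out_ptr + ptrs, vals, mask=mask)


def copy_rows(x: torch.Tensor) -> torch.Tensor:
    # One program per BLOCK_M rows; BLOCK_N must cover the full row width.
    M, N = x.shape
    out = torch.empty_like(x)
    BLOCK_M = 64
    BLOCK_N = triton.next_power_of_2(N)
    grid = (triton.cdiv(M, BLOCK_M),)
    copy_rows_kernel[grid](x, out, M, N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    return out
```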
@@ -83,6 +83,10 @@ def _flex_attention(
     """
 
+    if torch.compiler.is_dynamo_compiling():
+        # mark head_dim and dim always to be static
+        for x in [query, key, value]:
+            torch._dynamo.mark_static(x, 1)
+            torch._dynamo.mark_static(x, -1)
     out, _ = flex_attention_hop(query, key, value, score_mod)
     return out
 
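Even when the caller compiles with `dynamic=True`, the hunk above pins dim 1 and the last dim (num_heads and head_dim in the (B, H, S, D) layout used here) as static, presumably because the generated kernel specializes on them while batch and sequence length stay symbolic. A minimal sketch of the same `mark_static` pattern on a stand-in function (`attention_like` is an assumption, not this module's code):

```python
import torch
import torch.nn.functional as F


def attention_like(q, k, v):
    # Mirror of the pattern above: while dynamo is tracing, pin dim 1 and the
    # last dim so they stay concrete even under dynamic=True compilation.
    if torch.compiler.is_dynamo_compiling():
        for x in (q, k, v):
            torch._dynamo.mark_static(x, 1)
            torch._dynamo.mark_static(x, -1)
    return F.scaled_dot_product_attention(q, k, v)


compiled = torch.compile(attention_like, dynamic=True)
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
out = compiled(q, q, q)  # batch/sequence stay symbolic; heads and head_dim do not
```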