[Inductor] Flex attention supports dynamic shape (#125994)

## static shapes perf
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|----------------|
| Average |     0.692 |              |             |             |             |            |             |                |
| Max     |     0.855 |           16 |          16 |        4096 |        4096 |         64 | head_bias   | torch.bfloat16 |
| Min     |     0.419 |            8 |          16 |         512 |         512 |        256 | noop        | torch.bfloat16 |
```
## dynamic shapes perf
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average |     0.670 |              |             |             |             |            |               |                |
| Max     |     0.864 |           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |
| Min     |     0.376 |            8 |          16 |         512 |         512 |        256 | relative_bias | torch.bfloat16 |
```
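
For context, here is a rough sketch of the comparison behind these tables: eager `F.scaled_dot_product_attention` as the baseline against flex attention compiled with `dynamic=True`, for one illustrative configuration. The private import path and the `head_bias` score_mod below are assumptions for illustration, not the benchmark's exact code (the benchmark changes themselves are shown further down).

```python
# Minimal sketch, not the benchmark itself: compare eager SDPA against
# flex attention compiled with dynamic shapes for one illustrative config.
import torch
import torch.nn.functional as F
from torch.nn.attention._flex_attention import _flex_attention  # assumed import path


def head_bias(score, b, h, q_idx, kv_idx):
    # Hypothetical stand-in for the benchmark's score_mod variants.
    return score + h


B, H, S, D = 8, 16, 1024, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))

eager_out = F.scaled_dot_product_attention(q, k, v)           # baseline
compiled_flex = torch.compile(_flex_attention, dynamic=True)  # dynamic-shape compile
flex_out = compiled_flex(q, k, v, head_bias)
```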

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125994
Approved by: https://github.com/Chillee
Authored by: Yanbo Liang on 2024-05-15 04:43:24 +00:00
Committed by: PyTorch MergeBot
Commit: dfab69fdf1 (parent: 1485621ccb)
6 changed files with 161 additions and 17 deletions

## Changed file 1/6: flex attention benchmark script

```diff
@@ -1,3 +1,4 @@
+import argparse
 import itertools
 from collections import defaultdict
 from dataclasses import asdict, dataclass
@@ -98,7 +99,7 @@ def generate_inputs(
     return query, key, value


-def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
+def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults:
     device = torch.device("cuda")
     query, key, value = generate_inputs(
         config.batch_size,
@@ -113,7 +114,7 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
     def eager_sdpa(query, key, value, _):
         return F.scaled_dot_product_attention(query, key, value)

-    compiled_sdpa = torch.compile(_flex_attention)
+    compiled_sdpa = torch.compile(_flex_attention, dynamic=dynamic)

     score_mod = config.score_mod
@@ -242,16 +243,26 @@ def generate_experiment_configs() -> List[ExperimentConfig]:
     return all_configs


-def main():
+def main(dynamic=False):
     seed = 123
     np.random.seed(seed)
     torch.manual_seed(seed)

     results = []
     for config in tqdm(generate_experiment_configs()):
-        results.append(Experiment(config, run_single_experiment(config)))
+        results.append(
+            Experiment(config, run_single_experiment(config, dynamic=dynamic))
+        )
     print_results(results)


 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dynamic",
+        action="store_true",
+        help="Runs a dynamic shapes version of compiled flex attention.",
+    )
+    args = parser.parse_args()
+    main(args.dynamic)
```

## Changed file 2/6: Inductor templated attention tests

```diff
@@ -126,6 +126,19 @@ D = 64


 class TestTemplatedSDPA(InductorTestCase):
+    def _check_equal(self, golden_out, ref_out, compiled_out, dtype):
+        compiled_error = (golden_out - compiled_out).abs().mean()
+        ref_error = (golden_out - ref_out).abs().mean()
+        # Note, it seems like we really are less accurate than the float32
+        # computation, likely due to the online softmax
+        if dtype == torch.float32:
+            fudge_factor = 10.0
+        else:
+            fudge_factor = 1.1
+        if compiled_error > ref_error * fudge_factor:
+            msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
+            self.assertTrue(False, msg)
+
     def run_test(
         self,
         score_mod: Callable,
@@ -145,18 +158,114 @@ class TestTemplatedSDPA(InductorTestCase):
         )
         ref_out = sdpa_partial(q, k, v)
         compiled_out = compiled_sdpa(q, k, v)
-        compiled_error = (golden_out - compiled_out).abs().mean()
-        ref_error = (golden_out - ref_out).abs().mean()
-        # Note, it seems like we really are less accurate than the float32
-        # computation, likely due to the online softmax
-        if dtype == torch.float32:
-            fudge_factor = 10.0
-        else:
-            fudge_factor = 1.1
-        if compiled_error > ref_error * fudge_factor:
-            msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
-            self.assertTrue(False, msg)
+        self._check_equal(golden_out, ref_out, compiled_out, dtype)
+
+    def run_dynamic_test(
+        self,
+        score_mod: Callable,
+        dtype: torch.dtype = torch.float16,
+        B: int = B,
+        H: int = H,
+        S: int = S,
+        D: int = D,
+    ):
+        sdpa_partial = create_attention(score_mod)
+        # The first eager batch, shape (B, H, S, D)
+        q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out1 = sdpa_partial(
+            q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
+        )
+        ref_out1 = sdpa_partial(q1, k1, v1)
+
+        # The second eager batch, shape (B * 2, H, S / 2, D)
+        B = int(B * 2)
+        S = int(S / 2)
+        q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out2 = sdpa_partial(
+            q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
+        )
+        ref_out2 = sdpa_partial(q2, k2, v2)
+
+        # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
+        # We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation.
+        torch._dynamo.reset()
+        # Compiling with dynamic shape in the first batch.
+        compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
+        compiled_out1 = compiled_sdpa(q1, k1, v1)
+        self._check_equal(golden_out1, ref_out1, compiled_out1, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
+        # No re-compilation, use the compiled dynamic shape version.
+        compiled_out2 = compiled_sdpa(q2, k2, v2)
+        self._check_equal(golden_out2, ref_out2, compiled_out2, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
+
+    def run_automatic_dynamic_test(
+        self,
+        score_mod: Callable,
+        dtype: torch.dtype = torch.float16,
+        B: int = B,
+        H: int = H,
+        S: int = S,
+        D: int = D,
+    ):
+        sdpa_partial = create_attention(score_mod)
+        # The first eager batch, shape (B, H, S, D)
+        q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out1 = sdpa_partial(
+            q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
+        )
+        ref_out1 = sdpa_partial(q1, k1, v1)
+
+        # The second eager batch, shape (B * 2, H, S / 2, D)
+        B = int(B * 2)
+        S = int(S / 2)
+        q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out2 = sdpa_partial(
+            q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
+        )
+        ref_out2 = sdpa_partial(q2, k2, v2)
+
+        # The third eager batch, shape (B * 4, H, S / 4, D)
+        B = int(B * 2)
+        S = int(S / 2)
+        q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out3 = sdpa_partial(
+            q3.to(torch.float64), k3.to(torch.float64), v3.to(torch.float64)
+        )
+        ref_out3 = sdpa_partial(q3, k3, v3)
+
+        # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
+        # We check dynamo counters["frames"]["ok"] to ensure:
+        # 1, the first batch is compiled with static shape
+        # 2, the second batch is compiled with dynamic shape
+        # 3, no re-compilation in the third batch
+        torch._dynamo.reset()
+        # The first batch.
+        compiled_sdpa = torch.compile(sdpa_partial)
+        compiled_out1 = compiled_sdpa(q1, k1, v1)
+        self._check_equal(golden_out1, ref_out1, compiled_out1, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
+        # The second batch (automatic dynamic).
+        compiled_out2 = compiled_sdpa(q2, k2, v2)
+        self._check_equal(golden_out2, ref_out2, compiled_out2, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
+        # The third batch (no re-compilation).
+        compiled_out3 = compiled_sdpa(q3, k3, v3)
+        self._check_equal(golden_out3, ref_out3, compiled_out3, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)

     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
@@ -164,6 +273,20 @@
     def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable):
         self.run_test(score_mod, dtype)

+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes)
+    @common_utils.parametrize("score_mod", test_score_mods)
+    def test_builtin_score_mods_dynamic(self, dtype: torch.dtype, score_mod: Callable):
+        self.run_dynamic_test(score_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes)
+    @common_utils.parametrize("score_mod", test_score_mods)
+    def test_builtin_score_mods_automatic_dynamic(
+        self, dtype: torch.dtype, score_mod: Callable
+    ):
+        self.run_automatic_dynamic_test(score_mod, dtype)
+
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
     def test_skip_odd_keys(self, dtype: torch.dtype):
```
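
The counter assertions above lean on the default automatic-dynamic behavior of `torch.compile`: the first shape compiles statically, a second distinct shape triggers exactly one recompilation with dynamic shapes, and further shapes reuse that graph. A condensed sketch of the same pattern with a trivial stand-in function (an assumption for illustration, not flex attention):

```python
# Condensed sketch of the "automatic dynamic" pattern the tests assert,
# using a trivial function so it runs without a GPU.
import torch
import torch._dynamo
import torch._dynamo.utils


@torch.compile
def f(x):
    return x.sin()


torch._dynamo.reset()
f(torch.randn(4, 8))   # first shape: compiled with static shapes
assert torch._dynamo.utils.counters["frames"]["ok"] == 1
f(torch.randn(8, 4))   # second shape: one recompilation, now with dynamic shapes
assert torch._dynamo.utils.counters["frames"]["ok"] == 2
f(torch.randn(16, 2))  # third shape: served by the dynamic graph, no recompilation
assert torch._dynamo.utils.counters["frames"]["ok"] == 2
```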

## Changed file 3/6: Dynamo `source` helpers

```diff
@@ -628,3 +628,7 @@ def is_from_defaults(source: Source):
     if isinstance(source, ChainedSource):
         return is_from_defaults(source.base)
     return False
+
+
+def is_cell_contents(source: Source):
+    return isinstance(source, AttrSource) and source.member == "cell_contents"
```
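
`cell_contents` is the attribute through which a value captured in a Python closure cell is read, which is presumably how scalars captured by a `score_mod` closure reach Dynamo. A plain-Python illustration of what the new predicate matches (the `make_score_mod` helper is hypothetical):

```python
# Plain-Python illustration of what "cell_contents" refers to: a value captured
# by a closure lives in a cell object and is read via cell.cell_contents.
def make_score_mod(scale):
    def score_mod(score, b, h, q_idx, kv_idx):
        return score * scale  # `scale` is read from a closure cell at runtime
    return score_mod


score_mod = make_score_mod(0.125)
cell = score_mod.__closure__[0]
print(cell.cell_contents)  # 0.125 -- the value reached through an AttrSource with member "cell_contents"
```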

## Changed file 4/6: Dynamo `VariableBuilder`

```diff
@@ -58,6 +58,7 @@ from ..source import (
     FloatTensorSource,
     GetItemSource,
     GradSource,
+    is_cell_contents,
     is_constant_source,
     is_from_defaults,
     is_from_optimizer_source,
@@ -1166,6 +1167,7 @@ class VariableBuilder:
             # NN modules on the fly)
             or self.source.guard_source().is_nn_module()
             or is_from_defaults(self.source)
+            or is_cell_contents(self.source)
         ):
             self.install_guards(GuardBuilder.CONSTANT_MATCH)
             return ConstantVariable.create(value=value, source=self.source)
```

## Changed file 5/6: Inductor SDPA Triton template

```diff
@@ -162,7 +162,7 @@ sdpa_template = TritonTemplate(
     # TODO generalize and add proper mask support
     mask = (idx_m != -1) & (idx_d != -1)
-    {{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc")}}
+    {{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc", "mask")}}

     # TODO dont want to write this if we dont require grad
     if OUTPUT_LOGSUMEXP:
```
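
Passing `"mask"` to `store_output` turns the template's final write into a masked store. A generic Triton sketch, not the Inductor `sdpa_template` itself, of why that matters once sizes are dynamic: the tail tile along the rows can be partially out of bounds and must not write past the output buffer.

```python
# Generic Triton sketch (assumed shapes/strides), not the Inductor template:
# a masked tl.store so a partially out-of-bounds tail tile never writes past M x D.
import torch
import triton
import triton.language as tl


@triton.jit
def write_tile(out_ptr, M, D, stride_m, BLOCK_M: tl.constexpr, BLOCK_D: tl.constexpr):
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_D)
    acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)  # stands in for the attention accumulator
    mask = (offs_m[:, None] < M) & (offs_d[None, :] < D)
    tl.store(out_ptr + offs_m[:, None] * stride_m + offs_d[None, :], acc, mask=mask)


out = torch.empty(100, 64, device="cuda", dtype=torch.float32)
# 100 is not a multiple of BLOCK_M=32, so the last program relies on the mask.
write_tile[(triton.cdiv(100, 32),)](out, 100, 64, out.stride(0), BLOCK_M=32, BLOCK_D=64)
```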

## Changed file 6/6: `_flex_attention` frontend

```diff
@@ -83,6 +83,10 @@ def _flex_attention(
     """
     if torch.compiler.is_dynamo_compiling():
+        # mark head_dim and dim always to be static
+        for x in [query, key, value]:
+            torch._dynamo.mark_static(x, 1)
+            torch._dynamo.mark_static(x, -1)
         out, _ = flex_attention_hop(query, key, value, score_mod)
         return out
```
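
Marking dim 1 (num_heads) and dim -1 (head_dim) as static means that under `dynamic=True` only the batch and sequence dimensions become symbolic, so the generated kernel stays specialized on head count and head size. A hedged sketch of the same idea applied to ordinary inputs marked before compilation (a related but simpler usage than the in-function call above; `attn_like` is a hypothetical stand-in):

```python
# Hedged sketch: pin selected dimensions as static before compiling with
# dynamic shapes, so only the remaining dims are traced as symbolic sizes.
import torch
import torch._dynamo


def attn_like(q, k, v):
    # Stand-in for the flex attention higher-order op.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v


q = torch.randn(8, 16, 256, 64)
k = torch.randn(8, 16, 256, 64)
v = torch.randn(8, 16, 256, 64)
for x in (q, k, v):
    torch._dynamo.mark_static(x, 1)   # num_heads: keep static
    torch._dynamo.mark_static(x, -1)  # head_dim: keep static

compiled = torch.compile(attn_like, dynamic=True)
out = compiled(q, k, v)  # batch and sequence length remain dynamic
```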