pytorch/torch/_dynamo/variables/sdpa.py
jainapurva d039b14207 Grouped Query Attention (#128898)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new boolean argument, enable_gqa, to the scaled_dot_product_attention call
- When set, the third-to-last (head) dimension takes on GQA semantics: the query may carry more heads than the key/value tensors

Sample use case this would enable: Llama 3

```python
# Llama 3 8B-style call to SDPA (32 query heads, 8 KV heads)
import torch
import torch.nn.functional as F

batch, seq_len_q, seq_len_kv, D = 4, 2048, 2048, 128  # example sizes

query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = F.scaled_dot_product_attention(
    query, key, value, is_causal=True, enable_gqa=True
)

# Output shape: (batch, 32, seq_len_q, D)
```

### Design Choice:

- Check that Query.size(-3) == Key.size(-3) == Value.size(-3), or that Query.size(-3) % Key.size(-3) == 0
- If the number of query heads differs from the number of key/value heads, the key and value tensors are expanded along the head dimension with repeat_interleave to match the query's head count, so the attention computation stays correct and efficient (a sketch of this logic follows the list)
- By default the enable_gqa flag is set to False, so regular sdpa behavior is unchanged
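
A minimal sketch of this check-and-expand behavior, written as a standalone helper (illustrative only; the function name and structure are assumptions, not the actual PyTorch implementation):

```python
import torch

def _expand_kv_for_gqa(query, key, value, enable_gqa=False):
    # Illustrative sketch of the behavior described above; not PyTorch's
    # actual scaled_dot_product_attention implementation.
    q_heads, kv_heads = query.size(-3), key.size(-3)
    if not enable_gqa:
        # Regular SDPA: all three tensors must agree on the head count.
        assert q_heads == kv_heads == value.size(-3)
        return key, value
    # GQA constraint: Q_Heads % KV_Heads == 0 (and key/value heads match).
    assert key.size(-3) == value.size(-3) and q_heads % kv_heads == 0
    if q_heads != kv_heads:
        # Repeat each KV head so the head counts line up with the query.
        repeats = q_heads // kv_heads
        key = key.repeat_interleave(repeats, dim=-3)
        value = value.repeat_interleave(repeats, dim=-3)
    return key, value
```

With the Llama-style shapes above, key and value would go from 8 to 32 heads before the standard attention computation runs.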

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes, enable_gqa=True shows a substantial improvement in the forward runtime of sdpa:

| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time (enable_gqa=True) | forward_time (enable_gqa=False) |
| ---------- | ----------- | ------------ | --------- | ---------- | --------- | ------------------------------ | ------------------------------- |
| 1          | 32          | 8            | 2048      | 2048       | 2048      | 100.71                         | 119.70                          |
| 8          | 32          | 8            | 2048      | 2048       | 2048      | 539.78                         | 628.83                          |
| 16         | 32          | 8            | 2048      | 2048       | 2048      | 1056.81                        | 1225.48                         |
| 32         | 32          | 8            | 2048      | 2048       | 2048      | 2099.54                        | 2440.45                         |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**
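
For reference, here is a hedged sketch of how the sdpa.py comparison above could be timed with torch.utils.benchmark. It is not the script from #130634; the dtype, device, head_dim (embed_dim // q_num_heads), and modeling enable_gqa=False as pre-expanding the KV heads are all assumptions.

```python
import torch
import torch.nn.functional as F
from torch.utils.benchmark import Timer

# Roughly mirrors the first table row; head_dim = embed_dim // q_num_heads is an assumption.
batch, q_heads, kv_heads, seq_len, head_dim = 1, 32, 8, 2048, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

q = torch.rand(batch, q_heads, seq_len, head_dim, device=device, dtype=dtype)
k = torch.rand(batch, kv_heads, seq_len, head_dim, device=device, dtype=dtype)
v = torch.rand(batch, kv_heads, seq_len, head_dim, device=device, dtype=dtype)

def run(enable_gqa: bool):
    if enable_gqa:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
    # Baseline: materialize the repeated KV heads, then call regular SDPA.
    k_rep = k.repeat_interleave(q_heads // kv_heads, dim=-3)
    v_rep = v.repeat_interleave(q_heads // kv_heads, dim=-3)
    return F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)

for flag in (True, False):
    timer = Timer(stmt="run(flag)", globals={"run": run, "flag": flag})
    print(f"enable_gqa={flag}: {timer.timeit(50).median:.6f} s per call")
```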

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-29 21:49:06 +00:00

`torch/_dynamo/variables/sdpa.py` (96 lines, 3.2 KiB, Python):

```python
# mypy: ignore-errors

from inspect import getattr_static
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator

from ..bytecode_transformation import create_call_function
from ..exc import Unsupported
from .base import VariableTracker


class SDPAParamsVariable(VariableTracker):
    """Represents the c++ params struct for scaled dot product attention.
    This is a read-only container."""

    @staticmethod
    def create(tx: "InstructionTranslator", value, source):
        from torch.backends.cuda import SDPAParams
        from ..source import AttrSource
        from .builder import VariableBuilder
        from .torch import TorchInGraphFunctionVariable

        query_var = VariableBuilder(tx, AttrSource(source, "query"))(value.query)
        key_var = VariableBuilder(tx, AttrSource(source, "key"))(value.key)
        value_var = VariableBuilder(tx, AttrSource(source, "value"))(value.value)
        attn_mask_var = VariableBuilder(tx, AttrSource(source, "attn_mask"))(
            value.attn_mask
        )
        dropout_var = VariableBuilder(tx, AttrSource(source, "dropout"))(value.dropout)
        is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))(
            value.is_causal
        )
        enable_gqa_var = VariableBuilder(tx, AttrSource(source, "enable_gqa"))(
            value.enable_gqa
        )
        param_vars = [
            query_var,
            key_var,
            value_var,
            attn_mask_var,
            dropout_var,
            is_causal_var,
            enable_gqa_var,
        ]
        return TorchInGraphFunctionVariable(SDPAParams).call_function(
            tx, param_vars, {}
        )

    def __init__(self, proxy, param_vars, **kwargs):
        self.proxy = proxy
        self.param_vars = param_vars
        super().__init__(**kwargs)

    def reconstruct(self, codegen):
        assert self.source is None
        assert self.param_vars is not None
        codegen.add_push_null(
            lambda: codegen.load_import_from("torch._C", "_SDPAParams")
        )
        codegen.foreach(self.param_vars)
        codegen.extend_output(create_call_function(len(self.param_vars), False))

    def as_proxy(self):
        return self.proxy

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        import torch._C
        from ..source import AttrSource
        from .builder import wrap_fx_proxy
        from .misc import GetAttrVariable

        try:
            getattr_static(torch._C._SDPAParams, name)
        except AttributeError:
            # Using raise from is too verbose here
            raise Unsupported(
                f"Unsupported torch._C._SDPAParams attribute {name}"
            ) from None

        proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
        if self.source is not None:
            return wrap_fx_proxy(
                tx=tx, proxy=proxy, source=AttrSource(self.source, name)
            )
        else:
            return wrap_fx_proxy(tx=tx, proxy=proxy)

    @staticmethod
    def is_sdpa_params(value):
        from torch.backends.cuda import SDPAParams

        return value is SDPAParams
```
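
For context, here is a hedged sketch of the kind of user code this VariableTracker is meant to support under torch.compile. The SDPAParams argument order mirrors param_vars above; using can_use_flash_attention as the caller is an assumption about typical usage, not something taken from this PR.

```python
import torch
from torch.backends.cuda import SDPAParams, can_use_flash_attention

@torch.compile
def check_flash(q, k, v):
    # Constructing the C++ params struct is what SDPAParamsVariable.create models.
    # Field order mirrors param_vars above:
    # (query, key, value, attn_mask, dropout, is_causal, enable_gqa)
    params = SDPAParams(q, k, v, None, 0.0, True, True)
    # Attribute reads on params are routed through SDPAParamsVariable.var_getattr.
    return can_use_flash_attention(params, False)

# Assumes a CUDA device; flash-attention availability depends on the GPU and dtype.
q = torch.rand(2, 32, 128, 64, device="cuda", dtype=torch.float16)
k = torch.rand(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.rand(2, 8, 128, 64, device="cuda", dtype=torch.float16)
print(check_flash(q, k, v))
```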