### Approach: Using the current function declaration

**Constraint:** `Q_Heads % KV_Heads == 0`

**Major change:**
- Added a new argument `enable_gqa: bool` to the `scaled_dot_product_attention` call.
- It gives the third-to-last (head) dimension a grouped-query meaning: the query heads are grouped over the smaller number of key/value heads.

Sample use case this enables: Llama 3

```
# Llama 3 8B call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)
output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
# Output shape: (batch, 32, seq_len_q, D)
```

### Design Choice:
- Check that `Query.size(-3) == Key.size(-3) == Value.size(-3)`, or that `Query.size(-3) % Key.size(-3) == 0`.
- When the head counts differ, the function expands the key and value tensors to match the query tensor's number of heads using `repeat_interleave`, so attention is computed correctly and efficiently (see the sketch after this section).
- `enable_gqa` defaults to `False`, so regular SDPA behavior is unchanged.

### Benchmarks:
- **sdpa.py: #130634** — Across batch sizes, `enable_gqa=True` shows a substantial improvement in SDPA forward time.

| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time (enable_gqa=True) | forward_time (enable_gqa=False) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |

- **TorchTitan:** https://github.com/pytorch/torchtitan/pull/458

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
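The `repeat_interleave` design choice above can be illustrated with a minimal sketch. This is only an illustration of the intended semantics, not the actual kernel path; it assumes a PyTorch build recent enough to accept `enable_gqa`, and the shapes are made up for the example.

```python
import torch
import torch.nn.functional as F

batch, q_heads, kv_heads, seq_len, D = 2, 32, 8, 128, 64
query = torch.rand(batch, q_heads, seq_len, D)
key = torch.rand(batch, kv_heads, seq_len, D)
value = torch.rand(batch, kv_heads, seq_len, D)

# Grouped-query attention requires q_heads to be a multiple of kv_heads.
assert q_heads % kv_heads == 0
group_size = q_heads // kv_heads

# enable_gqa=True behaves as if each KV head were repeated to cover its
# group of query heads before running ordinary SDPA.
key_rep = key.repeat_interleave(group_size, dim=-3)
value_rep = value.repeat_interleave(group_size, dim=-3)

out_gqa = F.scaled_dot_product_attention(
    query, key, value, is_causal=True, enable_gqa=True
)
out_ref = F.scaled_dot_product_attention(
    query, key_rep, value_rep, is_causal=True
)

# The two paths should agree up to floating-point tolerance.
torch.testing.assert_close(out_gqa, out_ref)
```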
96 lines · 3.2 KiB · Python
# mypy: ignore-errors

from inspect import getattr_static
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator

from ..bytecode_transformation import create_call_function
from ..exc import Unsupported
from .base import VariableTracker


class SDPAParamsVariable(VariableTracker):
    """Represents the c++ params struct for scaled dot product attention.
    This is a read-only container."""

    @staticmethod
    def create(tx: "InstructionTranslator", value, source):
        from torch.backends.cuda import SDPAParams

        from ..source import AttrSource
        from .builder import VariableBuilder
        from .torch import TorchInGraphFunctionVariable

        # Wrap each field of the eager SDPAParams object in its own
        # VariableTracker, then rebuild the struct in-graph.
        query_var = VariableBuilder(tx, AttrSource(source, "query"))(value.query)
        key_var = VariableBuilder(tx, AttrSource(source, "key"))(value.key)
        value_var = VariableBuilder(tx, AttrSource(source, "value"))(value.value)
        attn_mask_var = VariableBuilder(tx, AttrSource(source, "attn_mask"))(
            value.attn_mask
        )
        dropout_var = VariableBuilder(tx, AttrSource(source, "dropout"))(value.dropout)
        is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))(
            value.is_causal
        )
        enable_gqa_var = VariableBuilder(tx, AttrSource(source, "enable_gqa"))(
            value.enable_gqa
        )
        param_vars = [
            query_var,
            key_var,
            value_var,
            attn_mask_var,
            dropout_var,
            is_causal_var,
            enable_gqa_var,
        ]
        return TorchInGraphFunctionVariable(SDPAParams).call_function(
            tx, param_vars, {}
        )

    def __init__(self, proxy, param_vars, **kwargs):
        self.proxy = proxy
        self.param_vars = param_vars
        super().__init__(**kwargs)

    def reconstruct(self, codegen):
        # Emit bytecode that rebuilds the params struct from its tracked fields.
        assert self.source is None
        assert self.param_vars is not None
        codegen.add_push_null(
            lambda: codegen.load_import_from("torch._C", "_SDPAParams")
        )
        codegen.foreach(self.param_vars)
        codegen.extend_output(create_call_function(len(self.param_vars), False))

    def as_proxy(self):
        return self.proxy

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        import torch._C

        from ..source import AttrSource
        from .builder import wrap_fx_proxy
        from .misc import GetAttrVariable

        # Only attributes that actually exist on torch._C._SDPAParams are readable.
        try:
            getattr_static(torch._C._SDPAParams, name)
        except AttributeError:
            # Using raise from is too verbose here
            raise Unsupported(
                f"Unsupported torch._C._SDPAParams attribute {name}"
            ) from None

        proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
        if self.source is not None:
            return wrap_fx_proxy(
                tx=tx, proxy=proxy, source=AttrSource(self.source, name)
            )
        else:
            return wrap_fx_proxy(tx=tx, proxy=proxy)

    @staticmethod
    def is_sdpa_params(value):
        from torch.backends.cuda import SDPAParams

        return value is SDPAParams
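For context on the struct this `VariableTracker` models, here is a small eager-mode sketch of constructing `SDPAParams` and handing it to a backend-selection helper. It assumes the seven-field constructor order seen in `param_vars` above and the `torch.backends.cuda.can_use_flash_attention` helper, requires a CUDA build, and is illustrative only.

```python
import torch
from torch.backends.cuda import SDPAParams, can_use_flash_attention

q = torch.rand(2, 32, 2048, 64, device="cuda", dtype=torch.float16)
k = torch.rand(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
v = torch.rand(2, 8, 2048, 64, device="cuda", dtype=torch.float16)

# Field order mirrors param_vars above:
# (query, key, value, attn_mask, dropout, is_causal, enable_gqa)
params = SDPAParams(q, k, v, None, 0.0, True, True)

# Backend-selection helpers consume this struct; when such a call appears in
# compiled code, Dynamo models the struct with SDPAParamsVariable.
print(can_use_flash_attention(params, debug=False))
```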