Compare commits

...

1 Commits

Author SHA1 Message Date
9abdc9505f create varlenmetadata 2025-11-14 11:54:34 -08:00

View File

@ -14,7 +14,7 @@ import torch
log = logging.getLogger(__name__)
__all__ = ["varlen_attn", "AuxRequest"]
__all__ = ["varlen_attn", "AuxRequest", "VarlenMetadata"]
@lru_cache(maxsize=8)
@ -23,6 +23,18 @@ def _should_use_cudnn(device_index: int) -> bool:
return False
class VarlenMetadata(NamedTuple):
"""
Cumulative sequence positions for queries and keys/values.
"""
cu_seq_q: torch.Tensor
cu_seq_k: torch.Tensor
max_q: int
max_k: int
class AuxRequest(NamedTuple):
"""
Request which auxiliary outputs to compute from varlen_attn.